diff mbox series

[net-next,v19,09/13] tcp: RX path for devmem TCP

Message ID 20240813211317.3381180-10-almasrymina@google.com
State Superseded
Headers show
Series [net-next,v19,01/13] netdev: add netdev_rx_queue_restart() | expand

Commit Message

Mina Almasry Aug. 13, 2024, 9:13 p.m. UTC
In tcp_recvmsg_locked(), detect if the skb being received by the user
is a devmem skb. In this case - if the user provided the MSG_SOCK_DEVMEM
flag - pass it to tcp_recvmsg_devmem() for custom handling.

tcp_recvmsg_devmem() copies any data in the skb header to the linear
buffer, and returns a cmsg to the user indicating the number of bytes
returned in the linear buffer.

tcp_recvmsg_devmem() then loops over the unaccessible devmem skb frags,
and returns to the user a cmsg_devmem indicating the location of the
data in the dmabuf device memory. cmsg_devmem contains this information:

1. the offset into the dmabuf where the payload starts. 'frag_offset'.
2. the size of the frag. 'frag_size'.
3. an opaque token 'frag_token' to return to the kernel when the buffer
is to be released.

The pages awaiting freeing are stored in the newly added
sk->sk_user_frags, and each page passed to userspace is get_page()'d.
This reference is dropped once the userspace indicates that it is
done reading this page.  All pages are released when the socket is
destroyed.

Signed-off-by: Willem de Bruijn <willemb@google.com>
Signed-off-by: Kaiyuan Zhang <kaiyuanz@google.com>
Signed-off-by: Mina Almasry <almasrymina@google.com>
Reviewed-by: Pavel Begunkov <asml.silence@gmail.com>
Reviewed-by: Eric Dumazet <edumazet@google.com>

---

v16:
- Fix number assignement (Arnd).

v13:
- Refactored user frags cleanup into a common function to avoid
  __maybe_unused. (Pavel)
- change to offset = 0 for some improved clarity.

v11:
- Refactor to common function te remove conditional lock sparse warning
  (Paolo)

v7:
- Updated the SO_DEVMEM_* uapi to use the next available entries (Arnd).
- Updated dmabuf_cmsg struct to be __u64 padded (Arnd).
- Squashed fix from Eric to initialize sk_user_frags for passive
  sockets (Eric).

v6
- skb->dmabuf -> skb->readable (Pavel)
- Fixed asm definitions of SO_DEVMEM_LINEAR/SO_DEVMEM_DMABUF not found
  on some archs.
- Squashed in locking optimizations from edumazet@google.com. With this
  change we lock the xarray once per per tcp_recvmsg_dmabuf() rather
  than once per frag in xa_alloc().

Changes in v1:
- Added dmabuf_id to dmabuf_cmsg (David/Stan).
- Devmem -> dmabuf (David).
- Change tcp_recvmsg_dmabuf() check to skb->dmabuf (Paolo).
- Use __skb_frag_ref() & napi_pp_put_page() for refcounting (Yunsheng).

RFC v3:
- Fixed issue with put_cmsg() failing silently.

---
 arch/alpha/include/uapi/asm/socket.h  |   5 +
 arch/mips/include/uapi/asm/socket.h   |   5 +
 arch/parisc/include/uapi/asm/socket.h |   5 +
 arch/sparc/include/uapi/asm/socket.h  |   5 +
 include/linux/socket.h                |   1 +
 include/net/netmem.h                  |  13 ++
 include/net/sock.h                    |   2 +
 include/uapi/asm-generic/socket.h     |   5 +
 include/uapi/linux/uio.h              |  13 ++
 net/ipv4/tcp.c                        | 255 +++++++++++++++++++++++++-
 net/ipv4/tcp_ipv4.c                   |  16 ++
 net/ipv4/tcp_minisocks.c              |   2 +
 12 files changed, 322 insertions(+), 5 deletions(-)

Comments

Taehee Yoo Aug. 17, 2024, 1:58 p.m. UTC | #1
On Wed, Aug 14, 2024 at 6:13 AM Mina Almasry <almasrymina@google.com> wrote:
>

Hi Mina,

> In tcp_recvmsg_locked(), detect if the skb being received by the user
> is a devmem skb. In this case - if the user provided the MSG_SOCK_DEVMEM
> flag - pass it to tcp_recvmsg_devmem() for custom handling.
>
> tcp_recvmsg_devmem() copies any data in the skb header to the linear
> buffer, and returns a cmsg to the user indicating the number of bytes
> returned in the linear buffer.
>
> tcp_recvmsg_devmem() then loops over the unaccessible devmem skb frags,
> and returns to the user a cmsg_devmem indicating the location of the
> data in the dmabuf device memory. cmsg_devmem contains this information:
>
> 1. the offset into the dmabuf where the payload starts. 'frag_offset'.

I have been testing this patch and I found a bug.
While testing it with the ncdevmem cmd, it fails to validate buffers
after some period.
This is because tcp_recvmsg_dmabuf() can't handle skb properly when
the parameter offset != 0.
The tcp_recvmsg_dmabuf() already has the code that handles skb if
offset is not 0 but it doesn't work for a specific case.

> 2. the size of the frag. 'frag_size'.
> 3. an opaque token 'frag_token' to return to the kernel when the buffer
> is to be released.
>
> The pages awaiting freeing are stored in the newly added
> sk->sk_user_frags, and each page passed to userspace is get_page()'d.
> This reference is dropped once the userspace indicates that it is
> done reading this page.  All pages are released when the socket is
> destroyed.
>
> Signed-off-by: Willem de Bruijn <willemb@google.com>
> Signed-off-by: Kaiyuan Zhang <kaiyuanz@google.com>
> Signed-off-by: Mina Almasry <almasrymina@google.com>
> Reviewed-by: Pavel Begunkov <asml.silence@gmail.com>
> Reviewed-by: Eric Dumazet <edumazet@google.com>
>
> ---
>
> v16:
> - Fix number assignement (Arnd).
>
> v13:
> - Refactored user frags cleanup into a common function to avoid
>   __maybe_unused. (Pavel)
> - change to offset = 0 for some improved clarity.
>
> v11:
> - Refactor to common function te remove conditional lock sparse warning
>   (Paolo)
>
> v7:
> - Updated the SO_DEVMEM_* uapi to use the next available entries (Arnd).
> - Updated dmabuf_cmsg struct to be __u64 padded (Arnd).
> - Squashed fix from Eric to initialize sk_user_frags for passive
>   sockets (Eric).
>
> v6
> - skb->dmabuf -> skb->readable (Pavel)
> - Fixed asm definitions of SO_DEVMEM_LINEAR/SO_DEVMEM_DMABUF not found
>   on some archs.
> - Squashed in locking optimizations from edumazet@google.com. With this
>   change we lock the xarray once per per tcp_recvmsg_dmabuf() rather
>   than once per frag in xa_alloc().
>
> Changes in v1:
> - Added dmabuf_id to dmabuf_cmsg (David/Stan).
> - Devmem -> dmabuf (David).
> - Change tcp_recvmsg_dmabuf() check to skb->dmabuf (Paolo).
> - Use __skb_frag_ref() & napi_pp_put_page() for refcounting (Yunsheng).
>
> RFC v3:
> - Fixed issue with put_cmsg() failing silently.
>
> ---
>  arch/alpha/include/uapi/asm/socket.h  |   5 +
>  arch/mips/include/uapi/asm/socket.h   |   5 +
>  arch/parisc/include/uapi/asm/socket.h |   5 +
>  arch/sparc/include/uapi/asm/socket.h  |   5 +
>  include/linux/socket.h                |   1 +
>  include/net/netmem.h                  |  13 ++
>  include/net/sock.h                    |   2 +
>  include/uapi/asm-generic/socket.h     |   5 +
>  include/uapi/linux/uio.h              |  13 ++
>  net/ipv4/tcp.c                        | 255 +++++++++++++++++++++++++-
>  net/ipv4/tcp_ipv4.c                   |  16 ++
>  net/ipv4/tcp_minisocks.c              |   2 +
>  12 files changed, 322 insertions(+), 5 deletions(-)
>
> diff --git a/arch/alpha/include/uapi/asm/socket.h b/arch/alpha/include/uapi/asm/socket.h
> index e94f621903fe..ef4656a41058 100644
> --- a/arch/alpha/include/uapi/asm/socket.h
> +++ b/arch/alpha/include/uapi/asm/socket.h
> @@ -140,6 +140,11 @@
>  #define SO_PASSPIDFD           76
>  #define SO_PEERPIDFD           77
>
> +#define SO_DEVMEM_LINEAR       78
> +#define SCM_DEVMEM_LINEAR      SO_DEVMEM_LINEAR
> +#define SO_DEVMEM_DMABUF       79
> +#define SCM_DEVMEM_DMABUF      SO_DEVMEM_DMABUF
> +
>  #if !defined(__KERNEL__)
>
>  #if __BITS_PER_LONG == 64
> diff --git a/arch/mips/include/uapi/asm/socket.h b/arch/mips/include/uapi/asm/socket.h
> index 60ebaed28a4c..414807d55e33 100644
> --- a/arch/mips/include/uapi/asm/socket.h
> +++ b/arch/mips/include/uapi/asm/socket.h
> @@ -151,6 +151,11 @@
>  #define SO_PASSPIDFD           76
>  #define SO_PEERPIDFD           77
>
> +#define SO_DEVMEM_LINEAR       78
> +#define SCM_DEVMEM_LINEAR      SO_DEVMEM_LINEAR
> +#define SO_DEVMEM_DMABUF       79
> +#define SCM_DEVMEM_DMABUF      SO_DEVMEM_DMABUF
> +
>  #if !defined(__KERNEL__)
>
>  #if __BITS_PER_LONG == 64
> diff --git a/arch/parisc/include/uapi/asm/socket.h b/arch/parisc/include/uapi/asm/socket.h
> index be264c2b1a11..2b817efd4544 100644
> --- a/arch/parisc/include/uapi/asm/socket.h
> +++ b/arch/parisc/include/uapi/asm/socket.h
> @@ -132,6 +132,11 @@
>  #define SO_PASSPIDFD           0x404A
>  #define SO_PEERPIDFD           0x404B
>
> +#define SO_DEVMEM_LINEAR       78
> +#define SCM_DEVMEM_LINEAR      SO_DEVMEM_LINEAR
> +#define SO_DEVMEM_DMABUF       79
> +#define SCM_DEVMEM_DMABUF      SO_DEVMEM_DMABUF
> +
>  #if !defined(__KERNEL__)
>
>  #if __BITS_PER_LONG == 64
> diff --git a/arch/sparc/include/uapi/asm/socket.h b/arch/sparc/include/uapi/asm/socket.h
> index 682da3714686..00248fc68977 100644
> --- a/arch/sparc/include/uapi/asm/socket.h
> +++ b/arch/sparc/include/uapi/asm/socket.h
> @@ -133,6 +133,11 @@
>  #define SO_PASSPIDFD             0x0055
>  #define SO_PEERPIDFD             0x0056
>
> +#define SO_DEVMEM_LINEAR         0x0057
> +#define SCM_DEVMEM_LINEAR        SO_DEVMEM_LINEAR
> +#define SO_DEVMEM_DMABUF         0x0058
> +#define SCM_DEVMEM_DMABUF        SO_DEVMEM_DMABUF
> +
>  #if !defined(__KERNEL__)
>
>
> diff --git a/include/linux/socket.h b/include/linux/socket.h
> index df9cdb8bbfb8..d18cc47e89bd 100644
> --- a/include/linux/socket.h
> +++ b/include/linux/socket.h
> @@ -327,6 +327,7 @@ struct ucred {
>                                           * plain text and require encryption
>                                           */
>
> +#define MSG_SOCK_DEVMEM 0x2000000      /* Receive devmem skbs as cmsg */
>  #define MSG_ZEROCOPY   0x4000000       /* Use user data in kernel path */
>  #define MSG_SPLICE_PAGES 0x8000000     /* Splice the pages from the iterator in sendmsg() */
>  #define MSG_FASTOPEN   0x20000000      /* Send data in TCP SYN */
> diff --git a/include/net/netmem.h b/include/net/netmem.h
> index 284f84a312c2..84043fbdd797 100644
> --- a/include/net/netmem.h
> +++ b/include/net/netmem.h
> @@ -65,6 +65,19 @@ static inline unsigned int net_iov_idx(const struct net_iov *niov)
>         return niov - net_iov_owner(niov)->niovs;
>  }
>
> +static inline unsigned long net_iov_virtual_addr(const struct net_iov *niov)
> +{
> +       struct dmabuf_genpool_chunk_owner *owner = net_iov_owner(niov);
> +
> +       return owner->base_virtual +
> +              ((unsigned long)net_iov_idx(niov) << PAGE_SHIFT);
> +}
> +
> +static inline u32 net_iov_binding_id(const struct net_iov *niov)
> +{
> +       return net_iov_owner(niov)->binding->id;
> +}
> +
>  static inline struct net_devmem_dmabuf_binding *
>  net_iov_binding(const struct net_iov *niov)
>  {
> diff --git a/include/net/sock.h b/include/net/sock.h
> index cce23ac4d514..f8ec869be238 100644
> --- a/include/net/sock.h
> +++ b/include/net/sock.h
> @@ -337,6 +337,7 @@ struct sk_filter;
>    *    @sk_txtime_report_errors: set report errors mode for SO_TXTIME
>    *    @sk_txtime_unused: unused txtime flags
>    *    @ns_tracker: tracker for netns reference
> +  *    @sk_user_frags: xarray of pages the user is holding a reference on.
>    */
>  struct sock {
>         /*
> @@ -542,6 +543,7 @@ struct sock {
>  #endif
>         struct rcu_head         sk_rcu;
>         netns_tracker           ns_tracker;
> +       struct xarray           sk_user_frags;
>  };
>
>  struct sock_bh_locked {
> diff --git a/include/uapi/asm-generic/socket.h b/include/uapi/asm-generic/socket.h
> index 8ce8a39a1e5f..e993edc9c0ee 100644
> --- a/include/uapi/asm-generic/socket.h
> +++ b/include/uapi/asm-generic/socket.h
> @@ -135,6 +135,11 @@
>  #define SO_PASSPIDFD           76
>  #define SO_PEERPIDFD           77
>
> +#define SO_DEVMEM_LINEAR       78
> +#define SCM_DEVMEM_LINEAR      SO_DEVMEM_LINEAR
> +#define SO_DEVMEM_DMABUF       79
> +#define SCM_DEVMEM_DMABUF      SO_DEVMEM_DMABUF
> +
>  #if !defined(__KERNEL__)
>
>  #if __BITS_PER_LONG == 64 || (defined(__x86_64__) && defined(__ILP32__))
> diff --git a/include/uapi/linux/uio.h b/include/uapi/linux/uio.h
> index 059b1a9147f4..3a22ddae376a 100644
> --- a/include/uapi/linux/uio.h
> +++ b/include/uapi/linux/uio.h
> @@ -20,6 +20,19 @@ struct iovec
>         __kernel_size_t iov_len; /* Must be size_t (1003.1g) */
>  };
>
> +struct dmabuf_cmsg {
> +       __u64 frag_offset;      /* offset into the dmabuf where the frag starts.
> +                                */
> +       __u32 frag_size;        /* size of the frag. */
> +       __u32 frag_token;       /* token representing this frag for
> +                                * DEVMEM_DONTNEED.
> +                                */
> +       __u32  dmabuf_id;       /* dmabuf id this frag belongs to. */
> +       __u32 flags;            /* Currently unused. Reserved for future
> +                                * uses.
> +                                */
> +};
> +
>  /*
>   *     UIO_MAXIOV shall be at least 16 1003.1g (5.4.1.1)
>   */
> diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c
> index 30e0aa38ba9b..40e7335dae6e 100644
> --- a/net/ipv4/tcp.c
> +++ b/net/ipv4/tcp.c
> @@ -471,6 +471,7 @@ void tcp_init_sock(struct sock *sk)
>
>         set_bit(SOCK_SUPPORT_ZC, &sk->sk_socket->flags);
>         sk_sockets_allocated_inc(sk);
> +       xa_init_flags(&sk->sk_user_frags, XA_FLAGS_ALLOC1);
>  }
>  EXPORT_SYMBOL(tcp_init_sock);
>
> @@ -2323,6 +2324,220 @@ static int tcp_inq_hint(struct sock *sk)
>         return inq;
>  }
>
> +/* batch __xa_alloc() calls and reduce xa_lock()/xa_unlock() overhead. */
> +struct tcp_xa_pool {
> +       u8              max; /* max <= MAX_SKB_FRAGS */
> +       u8              idx; /* idx <= max */
> +       __u32           tokens[MAX_SKB_FRAGS];
> +       netmem_ref      netmems[MAX_SKB_FRAGS];
> +};
> +
> +static void tcp_xa_pool_commit_locked(struct sock *sk, struct tcp_xa_pool *p)
> +{
> +       int i;
> +
> +       /* Commit part that has been copied to user space. */
> +       for (i = 0; i < p->idx; i++)
> +               __xa_cmpxchg(&sk->sk_user_frags, p->tokens[i], XA_ZERO_ENTRY,
> +                            (__force void *)p->netmems[i], GFP_KERNEL);
> +       /* Rollback what has been pre-allocated and is no longer needed. */
> +       for (; i < p->max; i++)
> +               __xa_erase(&sk->sk_user_frags, p->tokens[i]);
> +
> +       p->max = 0;
> +       p->idx = 0;
> +}
> +
> +static void tcp_xa_pool_commit(struct sock *sk, struct tcp_xa_pool *p)
> +{
> +       if (!p->max)
> +               return;
> +
> +       xa_lock_bh(&sk->sk_user_frags);
> +
> +       tcp_xa_pool_commit_locked(sk, p);
> +
> +       xa_unlock_bh(&sk->sk_user_frags);
> +}
> +
> +static int tcp_xa_pool_refill(struct sock *sk, struct tcp_xa_pool *p,
> +                             unsigned int max_frags)
> +{
> +       int err, k;
> +
> +       if (p->idx < p->max)
> +               return 0;
> +
> +       xa_lock_bh(&sk->sk_user_frags);
> +
> +       tcp_xa_pool_commit_locked(sk, p);
> +
> +       for (k = 0; k < max_frags; k++) {
> +               err = __xa_alloc(&sk->sk_user_frags, &p->tokens[k],
> +                                XA_ZERO_ENTRY, xa_limit_31b, GFP_KERNEL);
> +               if (err)
> +                       break;
> +       }
> +
> +       xa_unlock_bh(&sk->sk_user_frags);
> +
> +       p->max = k;
> +       p->idx = 0;
> +       return k ? 0 : err;
> +}
> +
> +/* On error, returns the -errno. On success, returns number of bytes sent to the
> + * user. May not consume all of @remaining_len.
> + */
> +static int tcp_recvmsg_dmabuf(struct sock *sk, const struct sk_buff *skb,
> +                             unsigned int offset, struct msghdr *msg,
> +                             int remaining_len)
> +{
> +       struct dmabuf_cmsg dmabuf_cmsg = { 0 };
> +       struct tcp_xa_pool tcp_xa_pool;
> +       unsigned int start;
> +       int i, copy, n;
> +       int sent = 0;
> +       int err = 0;
> +
> +       tcp_xa_pool.max = 0;
> +       tcp_xa_pool.idx = 0;
> +       do {
> +               start = skb_headlen(skb);
> +
> +               if (skb_frags_readable(skb)) {
> +                       err = -ENODEV;
> +                       goto out;
> +               }
> +
> +               /* Copy header. */
> +               copy = start - offset;
> +               if (copy > 0) {
> +                       copy = min(copy, remaining_len);
> +
> +                       n = copy_to_iter(skb->data + offset, copy,
> +                                        &msg->msg_iter);
> +                       if (n != copy) {
> +                               err = -EFAULT;
> +                               goto out;
> +                       }
> +
> +                       offset += copy;
> +                       remaining_len -= copy;
> +
> +                       /* First a dmabuf_cmsg for # bytes copied to user
> +                        * buffer.
> +                        */
> +                       memset(&dmabuf_cmsg, 0, sizeof(dmabuf_cmsg));
> +                       dmabuf_cmsg.frag_size = copy;
> +                       err = put_cmsg(msg, SOL_SOCKET, SO_DEVMEM_LINEAR,
> +                                      sizeof(dmabuf_cmsg), &dmabuf_cmsg);
> +                       if (err || msg->msg_flags & MSG_CTRUNC) {
> +                               msg->msg_flags &= ~MSG_CTRUNC;
> +                               if (!err)
> +                                       err = -ETOOSMALL;
> +                               goto out;
> +                       }
> +
> +                       sent += copy;
> +
> +                       if (remaining_len == 0)
> +                               goto out;
> +               }
> +
> +               /* after that, send information of dmabuf pages through a
> +                * sequence of cmsg
> +                */
> +               for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
> +                       skb_frag_t *frag = &skb_shinfo(skb)->frags[i];
> +                       struct net_iov *niov;
> +                       u64 frag_offset;
> +                       int end;
> +
> +                       /* !skb_frags_readable() should indicate that ALL the
> +                        * frags in this skb are dmabuf net_iovs. We're checking
> +                        * for that flag above, but also check individual frags
> +                        * here. If the tcp stack is not setting
> +                        * skb_frags_readable() correctly, we still don't want
> +                        * to crash here.
> +                        */
> +                       if (!skb_frag_net_iov(frag)) {
> +                               net_err_ratelimited("Found non-dmabuf skb with net_iov");
> +                               err = -ENODEV;
> +                               goto out;
> +                       }
> +
> +                       niov = skb_frag_net_iov(frag);
> +                       end = start + skb_frag_size(frag);
> +                       copy = end - offset;
> +
> +                       if (copy > 0) {
> +                               copy = min(copy, remaining_len);
> +
> +                               frag_offset = net_iov_virtual_addr(niov) +
> +                                             skb_frag_off(frag) + offset -
> +                                             start;
> +                               dmabuf_cmsg.frag_offset = frag_offset;
> +                               dmabuf_cmsg.frag_size = copy;
> +                               err = tcp_xa_pool_refill(sk, &tcp_xa_pool,
> +                                                        skb_shinfo(skb)->nr_frags - i);
> +                               if (err)
> +                                       goto out;
> +
> +                               /* Will perform the exchange later */
> +                               dmabuf_cmsg.frag_token = tcp_xa_pool.tokens[tcp_xa_pool.idx];
> +                               dmabuf_cmsg.dmabuf_id = net_iov_binding_id(niov);
> +
> +                               offset += copy;
> +                               remaining_len -= copy;
> +
> +                               err = put_cmsg(msg, SOL_SOCKET,
> +                                              SO_DEVMEM_DMABUF,
> +                                              sizeof(dmabuf_cmsg),
> +                                              &dmabuf_cmsg);
> +                               if (err || msg->msg_flags & MSG_CTRUNC) {
> +                                       msg->msg_flags &= ~MSG_CTRUNC;
> +                                       if (!err)
> +                                               err = -ETOOSMALL;
> +                                       goto out;
> +                               }
> +
> +                               atomic_long_inc(&niov->pp_ref_count);
> +                               tcp_xa_pool.netmems[tcp_xa_pool.idx++] = skb_frag_netmem(frag);
> +
> +                               sent += copy;
> +
> +                               if (remaining_len == 0)
> +                                       goto out;
> +                       }
> +                       start = end;
> +               }
> +
> +               tcp_xa_pool_commit(sk, &tcp_xa_pool);
> +               if (!remaining_len)
> +                       goto out;
> +
> +               /* if remaining_len is not satisfied yet, we need to go to the
> +                * next frag in the frag_list to satisfy remaining_len.
> +                */
> +               skb = skb_shinfo(skb)->frag_list ?: skb->next;
> +
> +               offset = 0;

If the offset is 5000 and only 4500 bytes are skipped at this point,
the offset should be 500, not 0.
We need to add a condition to set the offset correctly.

> +       } while (skb);
> +
> +       if (remaining_len) {
> +               err = -EFAULT;
> +               goto out;
> +       }
> +
> +out:
> +       tcp_xa_pool_commit(sk, &tcp_xa_pool);
> +       if (!sent)
> +               sent = err;
> +
> +       return sent;
> +}
> +
>  /*
>   *     This routine copies from a sock struct into the user buffer.
>   *
> @@ -2336,6 +2551,7 @@ static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len,
>                               int *cmsg_flags)
>  {
>         struct tcp_sock *tp = tcp_sk(sk);
> +       int last_copied_dmabuf = -1; /* uninitialized */
>         int copied = 0;
>         u32 peek_seq;
>         u32 *seq;
> @@ -2515,15 +2731,44 @@ static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len,
>                 }
>
>                 if (!(flags & MSG_TRUNC)) {
> -                       err = skb_copy_datagram_msg(skb, offset, msg, used);
> -                       if (err) {
> -                               /* Exception. Bailout! */
> -                               if (!copied)
> -                                       copied = -EFAULT;
> +                       if (last_copied_dmabuf != -1 &&
> +                           last_copied_dmabuf != !skb_frags_readable(skb))
>                                 break;
> +
> +                       if (skb_frags_readable(skb)) {
> +                               err = skb_copy_datagram_msg(skb, offset, msg,
> +                                                           used);
> +                               if (err) {
> +                                       /* Exception. Bailout! */
> +                                       if (!copied)
> +                                               copied = -EFAULT;
> +                                       break;
> +                               }
> +                       } else {
> +                               if (!(flags & MSG_SOCK_DEVMEM)) {
> +                                       /* dmabuf skbs can only be received
> +                                        * with the MSG_SOCK_DEVMEM flag.
> +                                        */
> +                                       if (!copied)
> +                                               copied = -EFAULT;
> +
> +                                       break;
> +                               }
> +
> +                               err = tcp_recvmsg_dmabuf(sk, skb, offset, msg,
> +                                                        used);
> +                               if (err <= 0) {
> +                                       if (!copied)
> +                                               copied = -EFAULT;
> +
> +                                       break;
> +                               }
> +                               used = err;
>                         }
>                 }
>
> +               last_copied_dmabuf = !skb_frags_readable(skb);
> +
>                 WRITE_ONCE(*seq, *seq + used);
>                 copied += used;
>                 len -= used;
> diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
> index fd17f25ff288..f3b2ae0823c4 100644
> --- a/net/ipv4/tcp_ipv4.c
> +++ b/net/ipv4/tcp_ipv4.c
> @@ -79,6 +79,7 @@
>  #include <linux/seq_file.h>
>  #include <linux/inetdevice.h>
>  #include <linux/btf_ids.h>
> +#include <linux/skbuff_ref.h>
>
>  #include <crypto/hash.h>
>  #include <linux/scatterlist.h>
> @@ -2507,10 +2508,25 @@ static void tcp_md5sig_info_free_rcu(struct rcu_head *head)
>  }
>  #endif
>
> +static void tcp_release_user_frags(struct sock *sk)
> +{
> +#ifdef CONFIG_PAGE_POOL
> +       unsigned long index;
> +       void *netmem;
> +
> +       xa_for_each(&sk->sk_user_frags, index, netmem)
> +               WARN_ON_ONCE(!napi_pp_put_page((__force netmem_ref)netmem));
> +#endif
> +}
> +
>  void tcp_v4_destroy_sock(struct sock *sk)
>  {
>         struct tcp_sock *tp = tcp_sk(sk);
>
> +       tcp_release_user_frags(sk);
> +
> +       xa_destroy(&sk->sk_user_frags);
> +
>         trace_tcp_destroy_sock(sk);
>
>         tcp_clear_xmit_timers(sk);
> diff --git a/net/ipv4/tcp_minisocks.c b/net/ipv4/tcp_minisocks.c
> index a19a9dbd3409..9ab87a41255d 100644
> --- a/net/ipv4/tcp_minisocks.c
> +++ b/net/ipv4/tcp_minisocks.c
> @@ -625,6 +625,8 @@ struct sock *tcp_create_openreq_child(const struct sock *sk,
>
>         __TCP_INC_STATS(sock_net(sk), TCP_MIB_PASSIVEOPENS);
>
> +       xa_init_flags(&newsk->sk_user_frags, XA_FLAGS_ALLOC1);
> +
>         return newsk;
>  }
>  EXPORT_SYMBOL(tcp_create_openreq_child);
> --
> 2.46.0.76.ge559c4bf1a-goog
>

I have been testing with modified code like below, it has been working
correctly for 24+ hours.
This modification is only for simple testing.
So, could you please look into this problem?

diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c
index 40e7335dae6e..b9df6ac28477 100644
--- a/net/ipv4/tcp.c
+++ b/net/ipv4/tcp.c
@@ -2396,9 +2396,11 @@ static int tcp_recvmsg_dmabuf(struct sock *sk,
const struct sk_buff *skb,
       struct dmabuf_cmsg dmabuf_cmsg = { 0 };
       struct tcp_xa_pool tcp_xa_pool;
       unsigned int start;
+       int skip = offset;
       int i, copy, n;
       int sent = 0;
       int err = 0;
+       int end = 0;

       tcp_xa_pool.max = 0;
       tcp_xa_pool.idx = 0;
@@ -2452,7 +2454,6 @@ static int tcp_recvmsg_dmabuf(struct sock *sk,
const struct sk_buff *skb,
                       skb_frag_t *frag = &skb_shinfo(skb)->frags[i];
                       struct net_iov *niov;
                       u64 frag_offset;
-                       int end;

                       /* !skb_frags_readable() should indicate that ALL the
                        * frags in this skb are dmabuf net_iovs. We're checking
@@ -2522,7 +2523,14 @@ static int tcp_recvmsg_dmabuf(struct sock *sk,
const struct sk_buff *skb,
                */
               skb = skb_shinfo(skb)->frag_list ?: skb->next;

-               offset = 0;
+               if (skip > 0) {
+                       skip -= end;
+                       offset = skip;
+               }
+               if (skip <= 0) {
+                       offset = 0;
+                       skip = 0;
+               }
       } while (skb);

       if (remaining_len) {

Thanks a lot!
Taehee Yoo
Mina Almasry Aug. 17, 2024, 3:13 p.m. UTC | #2
On Sat, Aug 17, 2024 at 9:58 AM Taehee Yoo <ap420073@gmail.com> wrote:
>
> On Wed, Aug 14, 2024 at 6:13 AM Mina Almasry <almasrymina@google.com> wrote:
> >
>
> Hi Mina,
>
> > In tcp_recvmsg_locked(), detect if the skb being received by the user
> > is a devmem skb. In this case - if the user provided the MSG_SOCK_DEVMEM
> > flag - pass it to tcp_recvmsg_devmem() for custom handling.
> >
> > tcp_recvmsg_devmem() copies any data in the skb header to the linear
> > buffer, and returns a cmsg to the user indicating the number of bytes
> > returned in the linear buffer.
> >
> > tcp_recvmsg_devmem() then loops over the unaccessible devmem skb frags,
> > and returns to the user a cmsg_devmem indicating the location of the
> > data in the dmabuf device memory. cmsg_devmem contains this information:
> >
> > 1. the offset into the dmabuf where the payload starts. 'frag_offset'.
>
> I have been testing this patch and I found a bug.

Thanks Taehee. It's exciting to see that you have gotten this far in
your testing!! You seem to have devmem TCP (almost) fully working!!
May I ask which driver this is? I assume it's bnxt. Do you have the
driver support somewhere on github or something? I'm curious what your
driver implementation looks like.

> While testing it with the ncdevmem cmd, it fails to validate buffers
> after some period.
> This is because tcp_recvmsg_dmabuf() can't handle skb properly when
> the parameter offset != 0.

Sadly I'm unable to reproduce this issue, but I think I know where to
suspect the bug is. Thanks for taking the time to root cause this and
provide a fix.

...

> > +               offset = 0;
>
> If the offset is 5000 and only 4500 bytes are skipped at this point,
> the offset should be 500, not 0.
> We need to add a condition to set the offset correctly.
>

I highly suspect this is a regression that was introduced in v13. In
v12 Pavel asked if offset can just be set to 0 here, and I didn't see
any reason why not, so I made the change:

    -+          offset = offset - start;
    ++          offset = 0;

It looks like we missed something. I suspect reverting that may
resolve the issue, because __skb_copy_datagram() in earlier kernels
modified offset like this and it's well tested. Can you test with this
change reverted? Diff like so:

diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c
index 40e7335dae6e..984e28c5d096 100644
--- a/net/ipv4/tcp.c
+++ b/net/ipv4/tcp.c
@@ -2522,7 +2522,7 @@ static int tcp_recvmsg_dmabuf(struct sock *sk,
const struct sk_buff *skb,
                 */
                skb = skb_shinfo(skb)->frag_list ?: skb->next;

-               offset = 0;
+               offset = offset - start;
        } while (skb);

        if (remaining_len) {

I'm running a long test to try to reproduce this issue, but I have ran
many long tests before and was not able to. For some reason my setup
is not able to reproduce this edge case. Are you doing anything
special with ncdevmem? Or simply running commands like these on the
server client?

server: ./ncdevmem -s SERVER -c CLIENT -l -p 5224 -v 7
client: yes $(echo -e \\x01\\x02\\x03\\x04\\x05\\x06) | tr \\n \\0 |
nc SERVER 5224 -p 5224
Taehee Yoo Aug. 17, 2024, 4:51 p.m. UTC | #3
On Sun, Aug 18, 2024 at 12:13 AM Mina Almasry <almasrymina@google.com> wrote:
>
> On Sat, Aug 17, 2024 at 9:58 AM Taehee Yoo <ap420073@gmail.com> wrote:
> >
> > On Wed, Aug 14, 2024 at 6:13 AM Mina Almasry <almasrymina@google.com> wrote:
> > >
> >
> > Hi Mina,
> >
> > > In tcp_recvmsg_locked(), detect if the skb being received by the user
> > > is a devmem skb. In this case - if the user provided the MSG_SOCK_DEVMEM
> > > flag - pass it to tcp_recvmsg_devmem() for custom handling.
> > >
> > > tcp_recvmsg_devmem() copies any data in the skb header to the linear
> > > buffer, and returns a cmsg to the user indicating the number of bytes
> > > returned in the linear buffer.
> > >
> > > tcp_recvmsg_devmem() then loops over the unaccessible devmem skb frags,
> > > and returns to the user a cmsg_devmem indicating the location of the
> > > data in the dmabuf device memory. cmsg_devmem contains this information:
> > >
> > > 1. the offset into the dmabuf where the payload starts. 'frag_offset'.
> >
> > I have been testing this patch and I found a bug.
>
> Thanks Taehee. It's exciting to see that you have gotten this far in
> your testing!! You seem to have devmem TCP (almost) fully working!!

Yes, Finally!!

> May I ask which driver this is? I assume it's bnxt. Do you have the
> driver support somewhere on github or something? I'm curious what your
> driver implementation looks like.

I'm using a modified bnxt_en driver.
NICs are BCM57412, BCM57508(currently only 57412).
I modified the driver by myself for devmem TCP.
The implementation is too rough, only for testing it.

>
> > While testing it with the ncdevmem cmd, it fails to validate buffers
> > after some period.
> > This is because tcp_recvmsg_dmabuf() can't handle skb properly when
> > the parameter offset != 0.
>
> Sadly I'm unable to reproduce this issue, but I think I know where to
> suspect the bug is. Thanks for taking the time to root cause this and
> provide a fix.

I'm not sure why, but I assume this bug appears when CPU utilization
peeks to almost 100%.

>
> ...
>
> > > + offset = 0;
> >
> > If the offset is 5000 and only 4500 bytes are skipped at this point,
> > the offset should be 500, not 0.
> > We need to add a condition to set the offset correctly.
> >
>
> I highly suspect this is a regression that was introduced in v13. In
> v12 Pavel asked if offset can just be set to 0 here, and I didn't see
> any reason why not, so I made the change:
>
> -+ offset = offset - start;
> ++ offset = 0;
>
> It looks like we missed something. I suspect reverting that may
> resolve the issue, because __skb_copy_datagram() in earlier kernels
> modified offset like this and it's well tested. Can you test with this
> change reverted? Diff like so:

No problem, I will test it.

>
> diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c
> index 40e7335dae6e..984e28c5d096 100644
> --- a/net/ipv4/tcp.c
> +++ b/net/ipv4/tcp.c
> @@ -2522,7 +2522,7 @@ static int tcp_recvmsg_dmabuf(struct sock *sk,
> const struct sk_buff *skb,
> */
> skb = skb_shinfo(skb)->frag_list ?: skb->next;
>
> - offset = 0;
> + offset = offset - start;
> } while (skb);
>
> if (remaining_len) {
>
> I'm running a long test to try to reproduce this issue, but I have ran
> many long tests before and was not able to. For some reason my setup
> is not able to reproduce this edge case. Are you doing anything
> special with ncdevmem? Or simply running commands like these on the
> server client?
>
> server: ./ncdevmem -s SERVER -c CLIENT -l -p 5224 -v 7
> client: yes $(echo -e \\x01\\x02\\x03\\x04\\x05\\x06) | tr \\n \\0 |
> nc SERVER 5224 -p 5224
>

Yes, I'm using the vanilla ncdevmem command and option in the
documentation in this series.

server:
./ncdevmem -s 192.168.1.4 -c 192.168.1.2 -f $interface -l -p 5000 -v 7 -t 0 -q 4

client:
yes $(echo -e \\x01\\x02\\x03\\x04\\x05\\x06) | \
tr \\n \\0 | \
head -c 1000000G | \
nc 192.168.1.4 5000 -p 5000


BTW, this is a ncdevmem failure message of the current implementation.
(offset = 0)

received frag_page=13205, in_page_offset=0, frag_offset=54087680,
frag_size=2896, token=288, total_received=1153657976, dmabuf_id=2
Validated buffer
received frag_page=13204, in_page_offset=0, frag_offset=54083584,
frag_size=1448, token=289, total_received=1153659424, dmabuf_id=2
Validated buffer
[1] received frag_page=13203, in_page_offset=0, frag_offset=54079488,
frag_size=2896, token=290, total_received=1153662320, dmabuf_id=2
Validated buffer
received frag_page=13202, in_page_offset=0, frag_offset=54075392,
frag_size=200, token=291, total_received=1153662520, dmabuf_id=2
Validated buffer
total_received=1153662520


recvmsg ret=819200
[2] received frag_page=13203, in_page_offset=0, frag_offset=54079488,
frag_size=2896, token=1, total_received=1153665416, dmabuf_id=2
Failed validation: expected=4, actual=2, index=0
Failed validation: expected=5, actual=3, index=1
Failed validation: expected=6, actual=4, index=2
Failed validation: expected=0, actual=5, index=3
Failed validation: expected=1, actual=6, index=4
Failed validation: expected=2, actual=0, index=5
Failed validation: expected=3, actual=1, index=6
Failed validation: expected=4, actual=2, index=7
Failed validation: expected=5, actual=3, index=8
Failed validation: expected=6, actual=4, index=9
Failed validation: expected=0, actual=5, index=10
Failed validation: expected=1, actual=6, index=11
Failed validation: expected=2, actual=0, index=12
Failed validation: expected=3, actual=1, index=13
Failed validation: expected=4, actual=2, index=14
Failed validation: expected=5, actual=3, index=15
Failed validation: expected=6, actual=4, index=16
Failed validation: expected=0, actual=5, index=17
Failed validation: expected=1, actual=6, index=18
Failed validation: expected=2, actual=0, index=19
Failed validation: expected=3, actual=1, index=20
./ncdevmem: validation failed.

Please look at the [1] and [2].
At the [1], The 13203 page is fully passed to userspace.
But 13202 is not, only 200 bytes are passed to userspace.
The 13203 page is passed to userspace fully, but 13202 is not.
Only 200 bytes are passed to userspace.

But at the [2], it receives 2896 bytes from 13203 again.
It should be 13202 and in_page_offset=200.

And I just started testing your suggestion, it seems to work correctly.
Here is the ncdevmem message, it's not a failure message.
(offset = offset - start)

received frag_page=13085, in_page_offset=0, frag_offset=53596160,
frag_size=2896, token=288, total_received=2233699704, dmabuf_id=2
Validated buffer
received frag_page=12931, in_page_offset=0, frag_offset=52965376,
frag_size=2896, token=289, total_received=2233702600, dmabuf_id=2
Validated buffer
[1] received frag_page=12916, in_page_offset=0, frag_offset=52903936,
frag_size=1392, token=290, total_received=2233703992, dmabuf_id=2
Validated buffer
total_received=2233703992


recvmsg ret=819200
[2] received frag_page=12916, in_page_offset=1392,
frag_offset=52905328, frag_size=1504, token=1,
total_received=2233705496,
dmabuf_id=2
Validated buffer
received frag_page=13244, in_page_offset=0, frag_offset=54247424,
frag_size=2896, token=2, total_received=2233708392, dmabuf_id=2
Validated buffer
received frag_page=13579, in_page_offset=0, frag_offset=55619584,
frag_size=1448, token=3, total_received=2233709840, dmabuf_id=2
Validated buffer
received frag_page=12315, in_page_offset=0, frag_offset=50442240,
frag_size=2896, token=4, total_received=2233712736, dmabuf_id=2
Validated buffer

At the [1], the 12916 page was not passed to userspace fully.
Only 1392 bytes are received.
At the [2], remain 1504 bytes are passed and the offset is 1392.
So, the 12916 page size is 2896, so it makes sense.

So, this is the reason why I think your suggestion is working correctly.
I have been still testing it, so I will report if it fails while testing.
But I think it works well about this corner case so far.

>
> --
> Thanks,
> Mina

Thanks a lot,
Taehee Yoo
diff mbox series

Patch

diff --git a/arch/alpha/include/uapi/asm/socket.h b/arch/alpha/include/uapi/asm/socket.h
index e94f621903fe..ef4656a41058 100644
--- a/arch/alpha/include/uapi/asm/socket.h
+++ b/arch/alpha/include/uapi/asm/socket.h
@@ -140,6 +140,11 @@ 
 #define SO_PASSPIDFD		76
 #define SO_PEERPIDFD		77
 
+#define SO_DEVMEM_LINEAR	78
+#define SCM_DEVMEM_LINEAR	SO_DEVMEM_LINEAR
+#define SO_DEVMEM_DMABUF	79
+#define SCM_DEVMEM_DMABUF	SO_DEVMEM_DMABUF
+
 #if !defined(__KERNEL__)
 
 #if __BITS_PER_LONG == 64
diff --git a/arch/mips/include/uapi/asm/socket.h b/arch/mips/include/uapi/asm/socket.h
index 60ebaed28a4c..414807d55e33 100644
--- a/arch/mips/include/uapi/asm/socket.h
+++ b/arch/mips/include/uapi/asm/socket.h
@@ -151,6 +151,11 @@ 
 #define SO_PASSPIDFD		76
 #define SO_PEERPIDFD		77
 
+#define SO_DEVMEM_LINEAR	78
+#define SCM_DEVMEM_LINEAR	SO_DEVMEM_LINEAR
+#define SO_DEVMEM_DMABUF	79
+#define SCM_DEVMEM_DMABUF	SO_DEVMEM_DMABUF
+
 #if !defined(__KERNEL__)
 
 #if __BITS_PER_LONG == 64
diff --git a/arch/parisc/include/uapi/asm/socket.h b/arch/parisc/include/uapi/asm/socket.h
index be264c2b1a11..2b817efd4544 100644
--- a/arch/parisc/include/uapi/asm/socket.h
+++ b/arch/parisc/include/uapi/asm/socket.h
@@ -132,6 +132,11 @@ 
 #define SO_PASSPIDFD		0x404A
 #define SO_PEERPIDFD		0x404B
 
+#define SO_DEVMEM_LINEAR	78
+#define SCM_DEVMEM_LINEAR	SO_DEVMEM_LINEAR
+#define SO_DEVMEM_DMABUF	79
+#define SCM_DEVMEM_DMABUF	SO_DEVMEM_DMABUF
+
 #if !defined(__KERNEL__)
 
 #if __BITS_PER_LONG == 64
diff --git a/arch/sparc/include/uapi/asm/socket.h b/arch/sparc/include/uapi/asm/socket.h
index 682da3714686..00248fc68977 100644
--- a/arch/sparc/include/uapi/asm/socket.h
+++ b/arch/sparc/include/uapi/asm/socket.h
@@ -133,6 +133,11 @@ 
 #define SO_PASSPIDFD             0x0055
 #define SO_PEERPIDFD             0x0056
 
+#define SO_DEVMEM_LINEAR         0x0057
+#define SCM_DEVMEM_LINEAR        SO_DEVMEM_LINEAR
+#define SO_DEVMEM_DMABUF         0x0058
+#define SCM_DEVMEM_DMABUF        SO_DEVMEM_DMABUF
+
 #if !defined(__KERNEL__)
 
 
diff --git a/include/linux/socket.h b/include/linux/socket.h
index df9cdb8bbfb8..d18cc47e89bd 100644
--- a/include/linux/socket.h
+++ b/include/linux/socket.h
@@ -327,6 +327,7 @@  struct ucred {
 					  * plain text and require encryption
 					  */
 
+#define MSG_SOCK_DEVMEM 0x2000000	/* Receive devmem skbs as cmsg */
 #define MSG_ZEROCOPY	0x4000000	/* Use user data in kernel path */
 #define MSG_SPLICE_PAGES 0x8000000	/* Splice the pages from the iterator in sendmsg() */
 #define MSG_FASTOPEN	0x20000000	/* Send data in TCP SYN */
diff --git a/include/net/netmem.h b/include/net/netmem.h
index 284f84a312c2..84043fbdd797 100644
--- a/include/net/netmem.h
+++ b/include/net/netmem.h
@@ -65,6 +65,19 @@  static inline unsigned int net_iov_idx(const struct net_iov *niov)
 	return niov - net_iov_owner(niov)->niovs;
 }
 
+static inline unsigned long net_iov_virtual_addr(const struct net_iov *niov)
+{
+	struct dmabuf_genpool_chunk_owner *owner = net_iov_owner(niov);
+
+	return owner->base_virtual +
+	       ((unsigned long)net_iov_idx(niov) << PAGE_SHIFT);
+}
+
+static inline u32 net_iov_binding_id(const struct net_iov *niov)
+{
+	return net_iov_owner(niov)->binding->id;
+}
+
 static inline struct net_devmem_dmabuf_binding *
 net_iov_binding(const struct net_iov *niov)
 {
diff --git a/include/net/sock.h b/include/net/sock.h
index cce23ac4d514..f8ec869be238 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -337,6 +337,7 @@  struct sk_filter;
   *	@sk_txtime_report_errors: set report errors mode for SO_TXTIME
   *	@sk_txtime_unused: unused txtime flags
   *	@ns_tracker: tracker for netns reference
+  *	@sk_user_frags: xarray of pages the user is holding a reference on.
   */
 struct sock {
 	/*
@@ -542,6 +543,7 @@  struct sock {
 #endif
 	struct rcu_head		sk_rcu;
 	netns_tracker		ns_tracker;
+	struct xarray		sk_user_frags;
 };
 
 struct sock_bh_locked {
diff --git a/include/uapi/asm-generic/socket.h b/include/uapi/asm-generic/socket.h
index 8ce8a39a1e5f..e993edc9c0ee 100644
--- a/include/uapi/asm-generic/socket.h
+++ b/include/uapi/asm-generic/socket.h
@@ -135,6 +135,11 @@ 
 #define SO_PASSPIDFD		76
 #define SO_PEERPIDFD		77
 
+#define SO_DEVMEM_LINEAR	78
+#define SCM_DEVMEM_LINEAR	SO_DEVMEM_LINEAR
+#define SO_DEVMEM_DMABUF	79
+#define SCM_DEVMEM_DMABUF	SO_DEVMEM_DMABUF
+
 #if !defined(__KERNEL__)
 
 #if __BITS_PER_LONG == 64 || (defined(__x86_64__) && defined(__ILP32__))
diff --git a/include/uapi/linux/uio.h b/include/uapi/linux/uio.h
index 059b1a9147f4..3a22ddae376a 100644
--- a/include/uapi/linux/uio.h
+++ b/include/uapi/linux/uio.h
@@ -20,6 +20,19 @@  struct iovec
 	__kernel_size_t iov_len; /* Must be size_t (1003.1g) */
 };
 
+struct dmabuf_cmsg {
+	__u64 frag_offset;	/* offset into the dmabuf where the frag starts.
+				 */
+	__u32 frag_size;	/* size of the frag. */
+	__u32 frag_token;	/* token representing this frag for
+				 * DEVMEM_DONTNEED.
+				 */
+	__u32  dmabuf_id;	/* dmabuf id this frag belongs to. */
+	__u32 flags;		/* Currently unused. Reserved for future
+				 * uses.
+				 */
+};
+
 /*
  *	UIO_MAXIOV shall be at least 16 1003.1g (5.4.1.1)
  */
diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c
index 30e0aa38ba9b..40e7335dae6e 100644
--- a/net/ipv4/tcp.c
+++ b/net/ipv4/tcp.c
@@ -471,6 +471,7 @@  void tcp_init_sock(struct sock *sk)
 
 	set_bit(SOCK_SUPPORT_ZC, &sk->sk_socket->flags);
 	sk_sockets_allocated_inc(sk);
+	xa_init_flags(&sk->sk_user_frags, XA_FLAGS_ALLOC1);
 }
 EXPORT_SYMBOL(tcp_init_sock);
 
@@ -2323,6 +2324,220 @@  static int tcp_inq_hint(struct sock *sk)
 	return inq;
 }
 
+/* batch __xa_alloc() calls and reduce xa_lock()/xa_unlock() overhead. */
+struct tcp_xa_pool {
+	u8		max; /* max <= MAX_SKB_FRAGS */
+	u8		idx; /* idx <= max */
+	__u32		tokens[MAX_SKB_FRAGS];
+	netmem_ref	netmems[MAX_SKB_FRAGS];
+};
+
+static void tcp_xa_pool_commit_locked(struct sock *sk, struct tcp_xa_pool *p)
+{
+	int i;
+
+	/* Commit part that has been copied to user space. */
+	for (i = 0; i < p->idx; i++)
+		__xa_cmpxchg(&sk->sk_user_frags, p->tokens[i], XA_ZERO_ENTRY,
+			     (__force void *)p->netmems[i], GFP_KERNEL);
+	/* Rollback what has been pre-allocated and is no longer needed. */
+	for (; i < p->max; i++)
+		__xa_erase(&sk->sk_user_frags, p->tokens[i]);
+
+	p->max = 0;
+	p->idx = 0;
+}
+
+static void tcp_xa_pool_commit(struct sock *sk, struct tcp_xa_pool *p)
+{
+	if (!p->max)
+		return;
+
+	xa_lock_bh(&sk->sk_user_frags);
+
+	tcp_xa_pool_commit_locked(sk, p);
+
+	xa_unlock_bh(&sk->sk_user_frags);
+}
+
+static int tcp_xa_pool_refill(struct sock *sk, struct tcp_xa_pool *p,
+			      unsigned int max_frags)
+{
+	int err, k;
+
+	if (p->idx < p->max)
+		return 0;
+
+	xa_lock_bh(&sk->sk_user_frags);
+
+	tcp_xa_pool_commit_locked(sk, p);
+
+	for (k = 0; k < max_frags; k++) {
+		err = __xa_alloc(&sk->sk_user_frags, &p->tokens[k],
+				 XA_ZERO_ENTRY, xa_limit_31b, GFP_KERNEL);
+		if (err)
+			break;
+	}
+
+	xa_unlock_bh(&sk->sk_user_frags);
+
+	p->max = k;
+	p->idx = 0;
+	return k ? 0 : err;
+}
+
+/* On error, returns the -errno. On success, returns number of bytes sent to the
+ * user. May not consume all of @remaining_len.
+ */
+static int tcp_recvmsg_dmabuf(struct sock *sk, const struct sk_buff *skb,
+			      unsigned int offset, struct msghdr *msg,
+			      int remaining_len)
+{
+	struct dmabuf_cmsg dmabuf_cmsg = { 0 };
+	struct tcp_xa_pool tcp_xa_pool;
+	unsigned int start;
+	int i, copy, n;
+	int sent = 0;
+	int err = 0;
+
+	tcp_xa_pool.max = 0;
+	tcp_xa_pool.idx = 0;
+	do {
+		start = skb_headlen(skb);
+
+		if (skb_frags_readable(skb)) {
+			err = -ENODEV;
+			goto out;
+		}
+
+		/* Copy header. */
+		copy = start - offset;
+		if (copy > 0) {
+			copy = min(copy, remaining_len);
+
+			n = copy_to_iter(skb->data + offset, copy,
+					 &msg->msg_iter);
+			if (n != copy) {
+				err = -EFAULT;
+				goto out;
+			}
+
+			offset += copy;
+			remaining_len -= copy;
+
+			/* First a dmabuf_cmsg for # bytes copied to user
+			 * buffer.
+			 */
+			memset(&dmabuf_cmsg, 0, sizeof(dmabuf_cmsg));
+			dmabuf_cmsg.frag_size = copy;
+			err = put_cmsg(msg, SOL_SOCKET, SO_DEVMEM_LINEAR,
+				       sizeof(dmabuf_cmsg), &dmabuf_cmsg);
+			if (err || msg->msg_flags & MSG_CTRUNC) {
+				msg->msg_flags &= ~MSG_CTRUNC;
+				if (!err)
+					err = -ETOOSMALL;
+				goto out;
+			}
+
+			sent += copy;
+
+			if (remaining_len == 0)
+				goto out;
+		}
+
+		/* after that, send information of dmabuf pages through a
+		 * sequence of cmsg
+		 */
+		for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
+			skb_frag_t *frag = &skb_shinfo(skb)->frags[i];
+			struct net_iov *niov;
+			u64 frag_offset;
+			int end;
+
+			/* !skb_frags_readable() should indicate that ALL the
+			 * frags in this skb are dmabuf net_iovs. We're checking
+			 * for that flag above, but also check individual frags
+			 * here. If the tcp stack is not setting
+			 * skb_frags_readable() correctly, we still don't want
+			 * to crash here.
+			 */
+			if (!skb_frag_net_iov(frag)) {
+				net_err_ratelimited("Found non-dmabuf skb with net_iov");
+				err = -ENODEV;
+				goto out;
+			}
+
+			niov = skb_frag_net_iov(frag);
+			end = start + skb_frag_size(frag);
+			copy = end - offset;
+
+			if (copy > 0) {
+				copy = min(copy, remaining_len);
+
+				frag_offset = net_iov_virtual_addr(niov) +
+					      skb_frag_off(frag) + offset -
+					      start;
+				dmabuf_cmsg.frag_offset = frag_offset;
+				dmabuf_cmsg.frag_size = copy;
+				err = tcp_xa_pool_refill(sk, &tcp_xa_pool,
+							 skb_shinfo(skb)->nr_frags - i);
+				if (err)
+					goto out;
+
+				/* Will perform the exchange later */
+				dmabuf_cmsg.frag_token = tcp_xa_pool.tokens[tcp_xa_pool.idx];
+				dmabuf_cmsg.dmabuf_id = net_iov_binding_id(niov);
+
+				offset += copy;
+				remaining_len -= copy;
+
+				err = put_cmsg(msg, SOL_SOCKET,
+					       SO_DEVMEM_DMABUF,
+					       sizeof(dmabuf_cmsg),
+					       &dmabuf_cmsg);
+				if (err || msg->msg_flags & MSG_CTRUNC) {
+					msg->msg_flags &= ~MSG_CTRUNC;
+					if (!err)
+						err = -ETOOSMALL;
+					goto out;
+				}
+
+				atomic_long_inc(&niov->pp_ref_count);
+				tcp_xa_pool.netmems[tcp_xa_pool.idx++] = skb_frag_netmem(frag);
+
+				sent += copy;
+
+				if (remaining_len == 0)
+					goto out;
+			}
+			start = end;
+		}
+
+		tcp_xa_pool_commit(sk, &tcp_xa_pool);
+		if (!remaining_len)
+			goto out;
+
+		/* if remaining_len is not satisfied yet, we need to go to the
+		 * next frag in the frag_list to satisfy remaining_len.
+		 */
+		skb = skb_shinfo(skb)->frag_list ?: skb->next;
+
+		offset = 0;
+	} while (skb);
+
+	if (remaining_len) {
+		err = -EFAULT;
+		goto out;
+	}
+
+out:
+	tcp_xa_pool_commit(sk, &tcp_xa_pool);
+	if (!sent)
+		sent = err;
+
+	return sent;
+}
+
 /*
  *	This routine copies from a sock struct into the user buffer.
  *
@@ -2336,6 +2551,7 @@  static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len,
 			      int *cmsg_flags)
 {
 	struct tcp_sock *tp = tcp_sk(sk);
+	int last_copied_dmabuf = -1; /* uninitialized */
 	int copied = 0;
 	u32 peek_seq;
 	u32 *seq;
@@ -2515,15 +2731,44 @@  static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len,
 		}
 
 		if (!(flags & MSG_TRUNC)) {
-			err = skb_copy_datagram_msg(skb, offset, msg, used);
-			if (err) {
-				/* Exception. Bailout! */
-				if (!copied)
-					copied = -EFAULT;
+			if (last_copied_dmabuf != -1 &&
+			    last_copied_dmabuf != !skb_frags_readable(skb))
 				break;
+
+			if (skb_frags_readable(skb)) {
+				err = skb_copy_datagram_msg(skb, offset, msg,
+							    used);
+				if (err) {
+					/* Exception. Bailout! */
+					if (!copied)
+						copied = -EFAULT;
+					break;
+				}
+			} else {
+				if (!(flags & MSG_SOCK_DEVMEM)) {
+					/* dmabuf skbs can only be received
+					 * with the MSG_SOCK_DEVMEM flag.
+					 */
+					if (!copied)
+						copied = -EFAULT;
+
+					break;
+				}
+
+				err = tcp_recvmsg_dmabuf(sk, skb, offset, msg,
+							 used);
+				if (err <= 0) {
+					if (!copied)
+						copied = -EFAULT;
+
+					break;
+				}
+				used = err;
 			}
 		}
 
+		last_copied_dmabuf = !skb_frags_readable(skb);
+
 		WRITE_ONCE(*seq, *seq + used);
 		copied += used;
 		len -= used;
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index fd17f25ff288..f3b2ae0823c4 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -79,6 +79,7 @@ 
 #include <linux/seq_file.h>
 #include <linux/inetdevice.h>
 #include <linux/btf_ids.h>
+#include <linux/skbuff_ref.h>
 
 #include <crypto/hash.h>
 #include <linux/scatterlist.h>
@@ -2507,10 +2508,25 @@  static void tcp_md5sig_info_free_rcu(struct rcu_head *head)
 }
 #endif
 
+static void tcp_release_user_frags(struct sock *sk)
+{
+#ifdef CONFIG_PAGE_POOL
+	unsigned long index;
+	void *netmem;
+
+	xa_for_each(&sk->sk_user_frags, index, netmem)
+		WARN_ON_ONCE(!napi_pp_put_page((__force netmem_ref)netmem));
+#endif
+}
+
 void tcp_v4_destroy_sock(struct sock *sk)
 {
 	struct tcp_sock *tp = tcp_sk(sk);
 
+	tcp_release_user_frags(sk);
+
+	xa_destroy(&sk->sk_user_frags);
+
 	trace_tcp_destroy_sock(sk);
 
 	tcp_clear_xmit_timers(sk);
diff --git a/net/ipv4/tcp_minisocks.c b/net/ipv4/tcp_minisocks.c
index a19a9dbd3409..9ab87a41255d 100644
--- a/net/ipv4/tcp_minisocks.c
+++ b/net/ipv4/tcp_minisocks.c
@@ -625,6 +625,8 @@  struct sock *tcp_create_openreq_child(const struct sock *sk,
 
 	__TCP_INC_STATS(sock_net(sk), TCP_MIB_PASSIVEOPENS);
 
+	xa_init_flags(&newsk->sk_user_frags, XA_FLAGS_ALLOC1);
+
 	return newsk;
 }
 EXPORT_SYMBOL(tcp_create_openreq_child);