@@ -403,18 +403,8 @@ EXPORT_SYMBOL(xsk_tx_peek_release_desc_batch);
static int xsk_wakeup(struct xdp_sock *xs, u8 flags)
{
struct net_device *dev = xs->dev;
- int err;
-
- rcu_read_lock();
- err = dev->netdev_ops->ndo_xsk_wakeup(dev, xs->queue_id, flags);
- rcu_read_unlock();
-
- return err;
-}
-static int xsk_zc_xmit(struct xdp_sock *xs)
-{
- return xsk_wakeup(xs, XDP_WAKEUP_TX);
+ return dev->netdev_ops->ndo_xsk_wakeup(dev, xs->queue_id, flags);
}
static void xsk_destruct_skb(struct sk_buff *skb)
@@ -533,6 +523,12 @@ static int xsk_generic_xmit(struct sock *sk)
mutex_lock(&xs->mutex);
+ /* Since we dropped the RCU read lock, the socket state might have changed. */
+ if (unlikely(!xsk_is_bound(xs))) {
+ err = -ENXIO;
+ goto out;
+ }
+
if (xs->queue_id >= xs->dev->real_num_tx_queues)
goto out;
@@ -596,16 +592,26 @@ static int xsk_generic_xmit(struct sock *sk)
return err;
}
-static int __xsk_sendmsg(struct sock *sk)
+static int xsk_xmit(struct sock *sk)
{
struct xdp_sock *xs = xdp_sk(sk);
+ int ret;
if (unlikely(!(xs->dev->flags & IFF_UP)))
return -ENETDOWN;
if (unlikely(!xs->tx))
return -ENOBUFS;
- return xs->zc ? xsk_zc_xmit(xs) : xsk_generic_xmit(sk);
+ if (xs->zc)
+ return xsk_wakeup(xs, XDP_WAKEUP_TX);
+
+ /* Drop the RCU lock since the SKB path might sleep. */
+ rcu_read_unlock();
+ ret = xsk_generic_xmit(sk);
+ /* Reaquire RCU lock before going into common code. */
+ rcu_read_lock();
+
+ return ret;
}
static bool xsk_no_wakeup(struct sock *sk)
@@ -619,7 +625,7 @@ static bool xsk_no_wakeup(struct sock *sk)
#endif
}
-static int xsk_sendmsg(struct socket *sock, struct msghdr *m, size_t total_len)
+static int __xsk_sendmsg(struct socket *sock, struct msghdr *m, size_t total_len)
{
bool need_wait = !(m->msg_flags & MSG_DONTWAIT);
struct sock *sk = sock->sk;
@@ -639,11 +645,22 @@ static int xsk_sendmsg(struct socket *sock, struct msghdr *m, size_t total_len)
pool = xs->pool;
if (pool->cached_need_wakeup & XDP_WAKEUP_TX)
- return __xsk_sendmsg(sk);
+ return xsk_xmit(sk);
return 0;
}
-static int xsk_recvmsg(struct socket *sock, struct msghdr *m, size_t len, int flags)
+static int xsk_sendmsg(struct socket *sock, struct msghdr *m, size_t total_len)
+{
+ int ret;
+
+ rcu_read_lock();
+ ret = __xsk_sendmsg(sock, m, total_len);
+ rcu_read_unlock();
+
+ return ret;
+}
+
+static int __xsk_recvmsg(struct socket *sock, struct msghdr *m, size_t len, int flags)
{
bool need_wait = !(flags & MSG_DONTWAIT);
struct sock *sk = sock->sk;
@@ -669,6 +686,17 @@ static int xsk_recvmsg(struct socket *sock, struct msghdr *m, size_t len, int fl
return 0;
}
+static int xsk_recvmsg(struct socket *sock, struct msghdr *m, size_t len, int flags)
+{
+ int ret;
+
+ rcu_read_lock();
+ ret = __xsk_recvmsg(sock, m, len, flags);
+ rcu_read_unlock();
+
+ return ret;
+}
+
static __poll_t xsk_poll(struct file *file, struct socket *sock,
struct poll_table_struct *wait)
{
@@ -679,8 +707,11 @@ static __poll_t xsk_poll(struct file *file, struct socket *sock,
sock_poll_wait(file, sock, wait);
- if (unlikely(!xsk_is_bound(xs)))
+ rcu_read_lock();
+ if (unlikely(!xsk_is_bound(xs))) {
+ rcu_read_unlock();
return mask;
+ }
pool = xs->pool;
@@ -689,7 +720,7 @@ static __poll_t xsk_poll(struct file *file, struct socket *sock,
xsk_wakeup(xs, pool->cached_need_wakeup);
else
/* Poll needs to drive Tx also in copy mode */
- __xsk_sendmsg(sk);
+ xsk_xmit(sk);
}
if (xs->rx && !xskq_prod_is_empty(xs->rx))
@@ -697,6 +728,7 @@ static __poll_t xsk_poll(struct file *file, struct socket *sock,
if (xs->tx && xsk_tx_writeable(xs))
mask |= EPOLLOUT | EPOLLWRNORM;
+ rcu_read_unlock();
return mask;
}
@@ -728,7 +760,6 @@ static void xsk_unbind_dev(struct xdp_sock *xs)
/* Wait for driver to stop using the xdp socket. */
xp_del_xsk(xs->pool, xs);
- xs->dev = NULL;
synchronize_net();
dev_put(dev);
}