diff --git a/net/xdp/xsk.c b/net/xdp/xsk.c index fa8fbb8fa3c8..9044073fbf22 100644 --- a/net/xdp/xsk.c +++ b/net/xdp/xsk.c @@ -305,9 +305,8 @@ out: } EXPORT_SYMBOL(xsk_umem_consume_tx); -static int xsk_zc_xmit(struct sock *sk) +static int xsk_zc_xmit(struct xdp_sock *xs) { - struct xdp_sock *xs = xdp_sk(sk); struct net_device *dev = xs->dev; return dev->netdev_ops->ndo_xsk_wakeup(dev, xs->queue_id, @@ -327,11 +326,10 @@ static void xsk_destruct_skb(struct sk_buff *skb) sock_wfree(skb); } -static int xsk_generic_xmit(struct sock *sk, struct msghdr *m, - size_t total_len) +static int xsk_generic_xmit(struct sock *sk) { - u32 max_batch = TX_BATCH_SIZE; struct xdp_sock *xs = xdp_sk(sk); + u32 max_batch = TX_BATCH_SIZE; bool sent_frame = false; struct xdp_desc desc; struct sk_buff *skb; @@ -394,6 +392,18 @@ out: return err; } +static int __xsk_sendmsg(struct sock *sk) +{ + struct xdp_sock *xs = xdp_sk(sk); + + if (unlikely(!(xs->dev->flags & IFF_UP))) + return -ENETDOWN; + if (unlikely(!xs->tx)) + return -ENOBUFS; + + return xs->zc ? xsk_zc_xmit(xs) : xsk_generic_xmit(sk); +} + static int xsk_sendmsg(struct socket *sock, struct msghdr *m, size_t total_len) { bool need_wait = !(m->msg_flags & MSG_DONTWAIT); @@ -402,21 +412,18 @@ static int xsk_sendmsg(struct socket *sock, struct msghdr *m, size_t total_len) if (unlikely(!xsk_is_bound(xs))) return -ENXIO; - if (unlikely(!(xs->dev->flags & IFF_UP))) - return -ENETDOWN; - if (unlikely(!xs->tx)) - return -ENOBUFS; - if (need_wait) + if (unlikely(need_wait)) return -EOPNOTSUPP; - return (xs->zc) ? xsk_zc_xmit(sk) : xsk_generic_xmit(sk, m, total_len); + return __xsk_sendmsg(sk); } static unsigned int xsk_poll(struct file *file, struct socket *sock, struct poll_table_struct *wait) { unsigned int mask = datagram_poll(file, sock, wait); - struct xdp_sock *xs = xdp_sk(sock->sk); + struct sock *sk = sock->sk; + struct xdp_sock *xs = xdp_sk(sk); struct net_device *dev; struct xdp_umem *umem; @@ -426,9 +433,14 @@ static unsigned int xsk_poll(struct file *file, struct socket *sock, dev = xs->dev; umem = xs->umem; - if (umem->need_wakeup) - dev->netdev_ops->ndo_xsk_wakeup(dev, xs->queue_id, - umem->need_wakeup); + if (umem->need_wakeup) { + if (dev->netdev_ops->ndo_xsk_wakeup) + dev->netdev_ops->ndo_xsk_wakeup(dev, xs->queue_id, + umem->need_wakeup); + else + /* Poll needs to drive Tx also in copy mode */ + __xsk_sendmsg(sk); + } if (xs->rx && !xskq_empty_desc(xs->rx)) mask |= POLLIN | POLLRDNORM;