skmsg: Extract __tcp_bpf_recvmsg() and tcp_bpf_wait_data()
Although these two functions are only used by TCP, they are not specific to TCP at all, both operate on skmsg and ingress_msg, so fit in net/core/skmsg.c very well. And we will need them for non-TCP, so rename and move them to skmsg.c and export them to modules. Signed-off-by: Cong Wang <cong.wang@bytedance.com> Signed-off-by: Alexei Starovoitov <ast@kernel.org> Link: https://lore.kernel.org/bpf/20210331023237.41094-13-xiyou.wangcong@gmail.com
This commit is contained in:
parent
d7f571188e
commit
2bc793e327
|
@ -125,6 +125,10 @@ int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
|
||||||
struct sk_msg *msg, u32 bytes);
|
struct sk_msg *msg, u32 bytes);
|
||||||
int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
|
int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
|
||||||
struct sk_msg *msg, u32 bytes);
|
struct sk_msg *msg, u32 bytes);
|
||||||
|
int sk_msg_wait_data(struct sock *sk, struct sk_psock *psock, int flags,
|
||||||
|
long timeo, int *err);
|
||||||
|
int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
|
||||||
|
int len, int flags);
|
||||||
|
|
||||||
static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes)
|
static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes)
|
||||||
{
|
{
|
||||||
|
|
|
@ -2209,8 +2209,6 @@ void tcp_bpf_clone(const struct sock *sk, struct sock *newsk);
|
||||||
|
|
||||||
int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg, u32 bytes,
|
int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg, u32 bytes,
|
||||||
int flags);
|
int flags);
|
||||||
int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
|
|
||||||
struct msghdr *msg, int len, int flags);
|
|
||||||
#endif /* CONFIG_NET_SOCK_MSG */
|
#endif /* CONFIG_NET_SOCK_MSG */
|
||||||
|
|
||||||
#if !defined(CONFIG_BPF_SYSCALL) || !defined(CONFIG_NET_SOCK_MSG)
|
#if !defined(CONFIG_BPF_SYSCALL) || !defined(CONFIG_NET_SOCK_MSG)
|
||||||
|
|
|
@ -399,6 +399,104 @@ out:
|
||||||
}
|
}
|
||||||
EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter);
|
EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter);
|
||||||
|
|
||||||
|
int sk_msg_wait_data(struct sock *sk, struct sk_psock *psock, int flags,
|
||||||
|
long timeo, int *err)
|
||||||
|
{
|
||||||
|
DEFINE_WAIT_FUNC(wait, woken_wake_function);
|
||||||
|
int ret = 0;
|
||||||
|
|
||||||
|
if (sk->sk_shutdown & RCV_SHUTDOWN)
|
||||||
|
return 1;
|
||||||
|
|
||||||
|
if (!timeo)
|
||||||
|
return ret;
|
||||||
|
|
||||||
|
add_wait_queue(sk_sleep(sk), &wait);
|
||||||
|
sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
|
||||||
|
ret = sk_wait_event(sk, &timeo,
|
||||||
|
!list_empty(&psock->ingress_msg) ||
|
||||||
|
!skb_queue_empty(&sk->sk_receive_queue), &wait);
|
||||||
|
sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
|
||||||
|
remove_wait_queue(sk_sleep(sk), &wait);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
EXPORT_SYMBOL_GPL(sk_msg_wait_data);
|
||||||
|
|
||||||
|
/* Receive sk_msg from psock->ingress_msg to @msg. */
|
||||||
|
int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
|
||||||
|
int len, int flags)
|
||||||
|
{
|
||||||
|
struct iov_iter *iter = &msg->msg_iter;
|
||||||
|
int peek = flags & MSG_PEEK;
|
||||||
|
struct sk_msg *msg_rx;
|
||||||
|
int i, copied = 0;
|
||||||
|
|
||||||
|
msg_rx = sk_psock_peek_msg(psock);
|
||||||
|
while (copied != len) {
|
||||||
|
struct scatterlist *sge;
|
||||||
|
|
||||||
|
if (unlikely(!msg_rx))
|
||||||
|
break;
|
||||||
|
|
||||||
|
i = msg_rx->sg.start;
|
||||||
|
do {
|
||||||
|
struct page *page;
|
||||||
|
int copy;
|
||||||
|
|
||||||
|
sge = sk_msg_elem(msg_rx, i);
|
||||||
|
copy = sge->length;
|
||||||
|
page = sg_page(sge);
|
||||||
|
if (copied + copy > len)
|
||||||
|
copy = len - copied;
|
||||||
|
copy = copy_page_to_iter(page, sge->offset, copy, iter);
|
||||||
|
if (!copy)
|
||||||
|
return copied ? copied : -EFAULT;
|
||||||
|
|
||||||
|
copied += copy;
|
||||||
|
if (likely(!peek)) {
|
||||||
|
sge->offset += copy;
|
||||||
|
sge->length -= copy;
|
||||||
|
if (!msg_rx->skb)
|
||||||
|
sk_mem_uncharge(sk, copy);
|
||||||
|
msg_rx->sg.size -= copy;
|
||||||
|
|
||||||
|
if (!sge->length) {
|
||||||
|
sk_msg_iter_var_next(i);
|
||||||
|
if (!msg_rx->skb)
|
||||||
|
put_page(page);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
/* Lets not optimize peek case if copy_page_to_iter
|
||||||
|
* didn't copy the entire length lets just break.
|
||||||
|
*/
|
||||||
|
if (copy != sge->length)
|
||||||
|
return copied;
|
||||||
|
sk_msg_iter_var_next(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (copied == len)
|
||||||
|
break;
|
||||||
|
} while (i != msg_rx->sg.end);
|
||||||
|
|
||||||
|
if (unlikely(peek)) {
|
||||||
|
msg_rx = sk_psock_next_msg(psock, msg_rx);
|
||||||
|
if (!msg_rx)
|
||||||
|
break;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
msg_rx->sg.start = i;
|
||||||
|
if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
|
||||||
|
msg_rx = sk_psock_dequeue_msg(psock);
|
||||||
|
kfree_sk_msg(msg_rx);
|
||||||
|
}
|
||||||
|
msg_rx = sk_psock_peek_msg(psock);
|
||||||
|
}
|
||||||
|
|
||||||
|
return copied;
|
||||||
|
}
|
||||||
|
EXPORT_SYMBOL_GPL(sk_msg_recvmsg);
|
||||||
|
|
||||||
static struct sk_msg *sk_psock_create_ingress_msg(struct sock *sk,
|
static struct sk_msg *sk_psock_create_ingress_msg(struct sock *sk,
|
||||||
struct sk_buff *skb)
|
struct sk_buff *skb)
|
||||||
{
|
{
|
||||||
|
|
|
@ -10,80 +10,6 @@
|
||||||
#include <net/inet_common.h>
|
#include <net/inet_common.h>
|
||||||
#include <net/tls.h>
|
#include <net/tls.h>
|
||||||
|
|
||||||
int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
|
|
||||||
struct msghdr *msg, int len, int flags)
|
|
||||||
{
|
|
||||||
struct iov_iter *iter = &msg->msg_iter;
|
|
||||||
int peek = flags & MSG_PEEK;
|
|
||||||
struct sk_msg *msg_rx;
|
|
||||||
int i, copied = 0;
|
|
||||||
|
|
||||||
msg_rx = sk_psock_peek_msg(psock);
|
|
||||||
while (copied != len) {
|
|
||||||
struct scatterlist *sge;
|
|
||||||
|
|
||||||
if (unlikely(!msg_rx))
|
|
||||||
break;
|
|
||||||
|
|
||||||
i = msg_rx->sg.start;
|
|
||||||
do {
|
|
||||||
struct page *page;
|
|
||||||
int copy;
|
|
||||||
|
|
||||||
sge = sk_msg_elem(msg_rx, i);
|
|
||||||
copy = sge->length;
|
|
||||||
page = sg_page(sge);
|
|
||||||
if (copied + copy > len)
|
|
||||||
copy = len - copied;
|
|
||||||
copy = copy_page_to_iter(page, sge->offset, copy, iter);
|
|
||||||
if (!copy)
|
|
||||||
return copied ? copied : -EFAULT;
|
|
||||||
|
|
||||||
copied += copy;
|
|
||||||
if (likely(!peek)) {
|
|
||||||
sge->offset += copy;
|
|
||||||
sge->length -= copy;
|
|
||||||
if (!msg_rx->skb)
|
|
||||||
sk_mem_uncharge(sk, copy);
|
|
||||||
msg_rx->sg.size -= copy;
|
|
||||||
|
|
||||||
if (!sge->length) {
|
|
||||||
sk_msg_iter_var_next(i);
|
|
||||||
if (!msg_rx->skb)
|
|
||||||
put_page(page);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
/* Lets not optimize peek case if copy_page_to_iter
|
|
||||||
* didn't copy the entire length lets just break.
|
|
||||||
*/
|
|
||||||
if (copy != sge->length)
|
|
||||||
return copied;
|
|
||||||
sk_msg_iter_var_next(i);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (copied == len)
|
|
||||||
break;
|
|
||||||
} while (i != msg_rx->sg.end);
|
|
||||||
|
|
||||||
if (unlikely(peek)) {
|
|
||||||
msg_rx = sk_psock_next_msg(psock, msg_rx);
|
|
||||||
if (!msg_rx)
|
|
||||||
break;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
msg_rx->sg.start = i;
|
|
||||||
if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
|
|
||||||
msg_rx = sk_psock_dequeue_msg(psock);
|
|
||||||
kfree_sk_msg(msg_rx);
|
|
||||||
}
|
|
||||||
msg_rx = sk_psock_peek_msg(psock);
|
|
||||||
}
|
|
||||||
|
|
||||||
return copied;
|
|
||||||
}
|
|
||||||
EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg);
|
|
||||||
|
|
||||||
static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
|
static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
|
||||||
struct sk_msg *msg, u32 apply_bytes, int flags)
|
struct sk_msg *msg, u32 apply_bytes, int flags)
|
||||||
{
|
{
|
||||||
|
@ -237,28 +163,6 @@ static bool tcp_bpf_stream_read(const struct sock *sk)
|
||||||
return !empty;
|
return !empty;
|
||||||
}
|
}
|
||||||
|
|
||||||
static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
|
|
||||||
int flags, long timeo, int *err)
|
|
||||||
{
|
|
||||||
DEFINE_WAIT_FUNC(wait, woken_wake_function);
|
|
||||||
int ret = 0;
|
|
||||||
|
|
||||||
if (sk->sk_shutdown & RCV_SHUTDOWN)
|
|
||||||
return 1;
|
|
||||||
|
|
||||||
if (!timeo)
|
|
||||||
return ret;
|
|
||||||
|
|
||||||
add_wait_queue(sk_sleep(sk), &wait);
|
|
||||||
sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
|
|
||||||
ret = sk_wait_event(sk, &timeo,
|
|
||||||
!list_empty(&psock->ingress_msg) ||
|
|
||||||
!skb_queue_empty(&sk->sk_receive_queue), &wait);
|
|
||||||
sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
|
|
||||||
remove_wait_queue(sk_sleep(sk), &wait);
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
|
static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
|
||||||
int nonblock, int flags, int *addr_len)
|
int nonblock, int flags, int *addr_len)
|
||||||
{
|
{
|
||||||
|
@ -278,13 +182,13 @@ static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
|
||||||
}
|
}
|
||||||
lock_sock(sk);
|
lock_sock(sk);
|
||||||
msg_bytes_ready:
|
msg_bytes_ready:
|
||||||
copied = __tcp_bpf_recvmsg(sk, psock, msg, len, flags);
|
copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
|
||||||
if (!copied) {
|
if (!copied) {
|
||||||
int data, err = 0;
|
int data, err = 0;
|
||||||
long timeo;
|
long timeo;
|
||||||
|
|
||||||
timeo = sock_rcvtimeo(sk, nonblock);
|
timeo = sock_rcvtimeo(sk, nonblock);
|
||||||
data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err);
|
data = sk_msg_wait_data(sk, psock, flags, timeo, &err);
|
||||||
if (data) {
|
if (data) {
|
||||||
if (!sk_psock_queue_empty(psock))
|
if (!sk_psock_queue_empty(psock))
|
||||||
goto msg_bytes_ready;
|
goto msg_bytes_ready;
|
||||||
|
|
|
@ -1789,8 +1789,8 @@ int tls_sw_recvmsg(struct sock *sk,
|
||||||
skb = tls_wait_data(sk, psock, flags, timeo, &err);
|
skb = tls_wait_data(sk, psock, flags, timeo, &err);
|
||||||
if (!skb) {
|
if (!skb) {
|
||||||
if (psock) {
|
if (psock) {
|
||||||
int ret = __tcp_bpf_recvmsg(sk, psock,
|
int ret = sk_msg_recvmsg(sk, psock, msg, len,
|
||||||
msg, len, flags);
|
flags);
|
||||||
|
|
||||||
if (ret > 0) {
|
if (ret > 0) {
|
||||||
decrypted += ret;
|
decrypted += ret;
|
||||||
|
|
Loading…
Reference in New Issue