Merge branch 'mptcp-new-features-for-mptcp-sockets-and-netlink-pm'

Mat Martineau says:

====================
mptcp: New features for MPTCP sockets and netlink PM

This collection of patches adds MPTCP socket support for a few socket
options, ioctls, and one ancillary data type (specifics for each are
listed below). There's also a patch modifying the netlink MPTCP path
manager API to allow setting the backup flag on a configured interface
using the endpoint ID instead of the full IP address.

Patches 1 & 2: TCP_INQ cmsg and selftests.

Patches 3 & 4: SIOCINQ, OUTQ, and OUTQNSD ioctls and selftests.

Patch 5: Change backup flag using endpoint ID.

Patches 6 & 7: IP_TOS socket option and selftests.

Patches 8-10: TCP_CORK and TCP_NODELAY socket options. Includes a tcp
change to expose __tcp_sock_set_cork() and __tcp_sock_set_nodelay() for
use by MPTCP.
====================

Link: https://lore.kernel.org/r/20211203223541.69364-1-mathew.j.martineau@linux.intel.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
This commit is contained in:
Jakub Kicinski 2021-12-07 11:36:36 -08:00
commit 59d58d93af
12 changed files with 1007 additions and 13 deletions

View File

@ -512,11 +512,13 @@ static inline u16 tcp_mss_clamp(const struct tcp_sock *tp, u16 mss)
int tcp_skb_shift(struct sk_buff *to, struct sk_buff *from, int pcount,
int shiftlen);
void __tcp_sock_set_cork(struct sock *sk, bool on);
void tcp_sock_set_cork(struct sock *sk, bool on);
int tcp_sock_set_keepcnt(struct sock *sk, int val);
int tcp_sock_set_keepidle_locked(struct sock *sk, int val);
int tcp_sock_set_keepidle(struct sock *sk, int val);
int tcp_sock_set_keepintvl(struct sock *sk, int val);
void __tcp_sock_set_nodelay(struct sock *sk, bool on);
void tcp_sock_set_nodelay(struct sock *sk);
void tcp_sock_set_quickack(struct sock *sk, int val);
int tcp_sock_set_syncnt(struct sock *sk, int val);

View File

@ -3207,7 +3207,7 @@ static void tcp_enable_tx_delay(void)
* TCP_CORK can be set together with TCP_NODELAY and it is stronger than
* TCP_NODELAY.
*/
static void __tcp_sock_set_cork(struct sock *sk, bool on)
void __tcp_sock_set_cork(struct sock *sk, bool on)
{
struct tcp_sock *tp = tcp_sk(sk);
@ -3235,7 +3235,7 @@ EXPORT_SYMBOL(tcp_sock_set_cork);
* However, when TCP_NODELAY is set we make an explicit push, which overrides
* even TCP_CORK for currently queued segments.
*/
static void __tcp_sock_set_nodelay(struct sock *sk, bool on)
void __tcp_sock_set_nodelay(struct sock *sk, bool on)
{
if (on) {
tcp_sk(sk)->nonagle |= TCP_NAGLE_OFF|TCP_NAGLE_PUSH;

View File

@ -1702,22 +1702,28 @@ next:
static int mptcp_nl_cmd_set_flags(struct sk_buff *skb, struct genl_info *info)
{
struct mptcp_pm_addr_entry addr = { .addr = { .family = AF_UNSPEC }, }, *entry;
struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
struct mptcp_pm_addr_entry addr, *entry;
struct net *net = sock_net(skb->sk);
u8 bkup = 0;
u8 bkup = 0, lookup_by_id = 0;
int ret;
ret = mptcp_pm_parse_addr(attr, info, true, &addr);
ret = mptcp_pm_parse_addr(attr, info, false, &addr);
if (ret < 0)
return ret;
if (addr.flags & MPTCP_PM_ADDR_FLAG_BACKUP)
bkup = 1;
if (addr.addr.family == AF_UNSPEC) {
lookup_by_id = 1;
if (!addr.addr.id)
return -EOPNOTSUPP;
}
list_for_each_entry(entry, &pernet->local_addr_list, list) {
if (addresses_equal(&entry->addr, &addr.addr, true)) {
if ((!lookup_by_id && addresses_equal(&entry->addr, &addr.addr, true)) ||
(lookup_by_id && entry->addr.id == addr.addr.id)) {
mptcp_nl_addr_backup(net, &entry->addr, bkup);
if (bkup)

View File

@ -22,6 +22,7 @@
#endif
#include <net/mptcp.h>
#include <net/xfrm.h>
#include <asm/ioctls.h>
#include "protocol.h"
#include "mib.h"
@ -46,6 +47,7 @@ struct mptcp_skb_cb {
/* Bit flags selecting which ancillary data (cmsg) mptcp_recvmsg() emits
 * once data has been copied to the caller.
 */
enum {
/* emit receive timestamps through tcp_recv_timestamp() */
MPTCP_CMSG_TS = BIT(0),
/* emit a TCP_CM_INQ cmsg carrying the mptcp_inq_hint() value */
MPTCP_CMSG_INQ = BIT(1),
};
static struct percpu_counter mptcp_sockets_allocated ____cacheline_aligned_in_smp;
@ -738,6 +740,7 @@ static bool __mptcp_ofo_queue(struct mptcp_sock *msk)
MPTCP_SKB_CB(skb)->map_seq, msk->ack_seq,
delta);
MPTCP_SKB_CB(skb)->offset += delta;
MPTCP_SKB_CB(skb)->map_seq += delta;
__skb_queue_tail(&sk->sk_receive_queue, skb);
}
msk->ack_seq = end_seq;
@ -1499,7 +1502,7 @@ static void mptcp_update_post_push(struct mptcp_sock *msk,
msk->snd_nxt = snd_nxt_new;
}
static void mptcp_check_and_set_pending(struct sock *sk)
void mptcp_check_and_set_pending(struct sock *sk)
{
if (mptcp_send_head(sk) &&
!test_bit(MPTCP_PUSH_PENDING, &mptcp_sk(sk)->flags))
@ -1784,8 +1787,10 @@ static int __mptcp_recvmsg_mskq(struct mptcp_sock *msk,
copied += count;
if (count < data_len) {
if (!(flags & MSG_PEEK))
if (!(flags & MSG_PEEK)) {
MPTCP_SKB_CB(skb)->offset += count;
MPTCP_SKB_CB(skb)->map_seq += count;
}
break;
}
@ -1965,6 +1970,27 @@ static bool __mptcp_move_skbs(struct mptcp_sock *msk)
return !skb_queue_empty(&msk->receive_queue);
}
/* Estimate how many bytes are currently readable on the MPTCP socket.
 *
 * With data queued on msk->receive_queue the hint is the distance between
 * msk->ack_seq and the map_seq of the first queued skb, capped at INT_MAX.
 * With an empty queue the hint is 1 when the socket is closed or the
 * receive side is shut down, and 0 otherwise.
 */
static unsigned int mptcp_inq_hint(const struct sock *sk)
{
	const struct mptcp_sock *msk = mptcp_sk(sk);
	const struct sk_buff *head = skb_peek(&msk->receive_queue);
	u64 pending;

	if (!head) {
		/* nothing queued: a non-zero hint flags that no more data
		 * is going to show up on this socket
		 */
		if (sk->sk_state == TCP_CLOSE ||
		    (sk->sk_shutdown & RCV_SHUTDOWN))
			return 1;
		return 0;
	}

	pending = msk->ack_seq - MPTCP_SKB_CB(head)->map_seq;

	return pending >= INT_MAX ? INT_MAX : (unsigned int)pending;
}
static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
int nonblock, int flags, int *addr_len)
{
@ -1989,6 +2015,9 @@ static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
len = min_t(size_t, len, INT_MAX);
target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
if (unlikely(msk->recvmsg_inq))
cmsg_flags = MPTCP_CMSG_INQ;
while (copied < len) {
int bytes_read;
@ -2062,6 +2091,12 @@ out_err:
if (cmsg_flags && copied >= 0) {
if (cmsg_flags & MPTCP_CMSG_TS)
tcp_recv_timestamp(msg, sk, &tss);
if (cmsg_flags & MPTCP_CMSG_INQ) {
unsigned int inq = mptcp_inq_hint(sk);
put_cmsg(msg, SOL_TCP, TCP_CM_INQ, sizeof(inq), &inq);
}
}
pr_debug("msk=%p rx queue empty=%d:%d copied=%d",
@ -3177,6 +3212,57 @@ static int mptcp_forward_alloc_get(const struct sock *sk)
return sk->sk_forward_alloc + mptcp_sk(sk)->rmem_fwd_alloc;
}
/* Compute the SIOCOUTQ/SIOCOUTQNSD answer: the number of bytes between the
 * sequence number @v and msk->write_seq, capped at INT_MAX.
 *
 * Returns -EINVAL on a listening socket and 0 while the connection is
 * still in SYN_SENT or SYN_RECV.
 */
static int mptcp_ioctl_outq(const struct mptcp_sock *msk, u64 v)
{
	const struct sock *sk = (void *)msk;
	int state = sk->sk_state;

	if (state == TCP_LISTEN)
		return -EINVAL;

	if (state == TCP_SYN_SENT || state == TCP_SYN_RECV)
		return 0;

	return (int)min_t(u64, msk->write_seq - v, INT_MAX);
}
/* ioctl handler installed as mptcp_prot.ioctl.
 *
 * SIOCINQ:     readable byte hint, as computed by mptcp_inq_hint().
 * SIOCOUTQ:    bytes between msk->snd_una and msk->write_seq.
 * SIOCOUTQNSD: bytes between msk->snd_nxt and msk->write_seq.
 *
 * The answer is stored as an int at the user address @arg.  Returns the
 * put_user() result, -EINVAL for SIOCINQ on a listening socket, a negative
 * value propagated from mptcp_ioctl_outq(), or -ENOIOCTLCMD for any other
 * command.
 */
static int mptcp_ioctl(struct sock *sk, int cmd, unsigned long arg)
{
struct mptcp_sock *msk = mptcp_sk(sk);
bool slow;
int answ;
switch (cmd) {
case SIOCINQ:
if (sk->sk_state == TCP_LISTEN)
return -EINVAL;
lock_sock(sk);
/* refresh msk->receive_queue before sampling the hint
 * NOTE(review): helper body is not fully visible in this excerpt
 */
__mptcp_move_skbs(msk);
answ = mptcp_inq_hint(sk);
release_sock(sk);
break;
case SIOCOUTQ:
slow = lock_sock_fast(sk);
/* everything written but not yet acknowledged at the MPTCP level */
answ = mptcp_ioctl_outq(msk, READ_ONCE(msk->snd_una));
unlock_sock_fast(sk, slow);
break;
case SIOCOUTQNSD:
slow = lock_sock_fast(sk);
/* everything written but not yet sent */
answ = mptcp_ioctl_outq(msk, msk->snd_nxt);
unlock_sock_fast(sk, slow);
break;
default:
return -ENOIOCTLCMD;
}
return put_user(answ, (int __user *)arg);
}
static struct proto mptcp_prot = {
.name = "MPTCP",
.owner = THIS_MODULE,
@ -3189,6 +3275,7 @@ static struct proto mptcp_prot = {
.shutdown = mptcp_shutdown,
.destroy = mptcp_destroy,
.sendmsg = mptcp_sendmsg,
.ioctl = mptcp_ioctl,
.recvmsg = mptcp_recvmsg,
.release_cb = mptcp_release_cb,
.hash = mptcp_hash,

View File

@ -249,6 +249,9 @@ struct mptcp_sock {
bool rcv_fastclose;
bool use_64bit_ack; /* Set when we received a 64-bit DSN */
bool csum_enabled;
u8 recvmsg_inq:1,
cork:1,
nodelay:1;
spinlock_t join_list_lock;
struct work_struct work;
struct sk_buff *ooo_last_skb;
@ -554,6 +557,7 @@ unsigned int mptcp_stale_loss_cnt(const struct net *net);
void mptcp_subflow_fully_established(struct mptcp_subflow_context *subflow,
struct mptcp_options_received *mp_opt);
bool __mptcp_retransmit_pending_data(struct sock *sk);
void mptcp_check_and_set_pending(struct sock *sk);
void __mptcp_push_pending(struct sock *sk, unsigned int flags);
bool mptcp_subflow_data_available(struct sock *sk);
void __init mptcp_subflow_init(void);

View File

@ -557,6 +557,7 @@ static bool mptcp_supported_sockopt(int level, int optname)
case TCP_TIMESTAMP:
case TCP_NOTSENT_LOWAT:
case TCP_TX_DELAY:
case TCP_INQ:
return true;
}
@ -568,7 +569,6 @@ static bool mptcp_supported_sockopt(int level, int optname)
/* TCP_FASTOPEN_KEY, TCP_FASTOPEN, TCP_FASTOPEN_CONNECT and
* TCP_FASTOPEN_NO_COOKIE are not supported: fastopen is currently unsupported
*/
/* TCP_INQ is currently unsupported, needs some recvmsg work */
}
return false;
}
@ -616,6 +616,66 @@ static int mptcp_setsockopt_sol_tcp_congestion(struct mptcp_sock *msk, sockptr_t
return ret;
}
/* setsockopt(SOL_TCP, TCP_CORK) for an MPTCP socket.
 *
 * Reads an int from @optval, records its truth value in msk->cork and
 * applies it to every subflow currently attached to @msk.
 *
 * Returns 0 on success, -EINVAL when @optlen is smaller than an int and
 * -EFAULT when the value cannot be copied in.
 */
static int mptcp_setsockopt_sol_tcp_cork(struct mptcp_sock *msk, sockptr_t optval,
unsigned int optlen)
{
struct mptcp_subflow_context *subflow;
struct sock *sk = (struct sock *)msk;
int val;
if (optlen < sizeof(int))
return -EINVAL;
if (copy_from_sockptr(&val, optval, sizeof(val)))
return -EFAULT;
/* msk lock is held across the whole update; each subflow is locked
 * individually while its own cork state is changed
 */
lock_sock(sk);
sockopt_seq_inc(msk);
msk->cork = !!val;
mptcp_for_each_subflow(msk, subflow) {
struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
lock_sock(ssk);
__tcp_sock_set_cork(ssk, !!val);
release_sock(ssk);
}
/* uncorking: flag any data still waiting at the MPTCP level as pending */
if (!val)
mptcp_check_and_set_pending(sk);
release_sock(sk);
return 0;
}
/* setsockopt(SOL_TCP, TCP_NODELAY) for an MPTCP socket.
 *
 * Reads an int from @optval, records its truth value in msk->nodelay and
 * applies it to every subflow currently attached to @msk.
 *
 * Returns 0 on success, -EINVAL when @optlen is smaller than an int and
 * -EFAULT when the value cannot be copied in.
 */
static int mptcp_setsockopt_sol_tcp_nodelay(struct mptcp_sock *msk, sockptr_t optval,
unsigned int optlen)
{
struct mptcp_subflow_context *subflow;
struct sock *sk = (struct sock *)msk;
int val;
if (optlen < sizeof(int))
return -EINVAL;
if (copy_from_sockptr(&val, optval, sizeof(val)))
return -EFAULT;
/* msk lock is held across the whole update; each subflow is locked
 * individually while its own nodelay state is changed
 */
lock_sock(sk);
sockopt_seq_inc(msk);
msk->nodelay = !!val;
mptcp_for_each_subflow(msk, subflow) {
struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
lock_sock(ssk);
__tcp_sock_set_nodelay(ssk, !!val);
release_sock(ssk);
}
/* enabling nodelay: flag any data still waiting at the MPTCP level as pending */
if (val)
mptcp_check_and_set_pending(sk);
release_sock(sk);
return 0;
}
static int mptcp_setsockopt_sol_ip_set_transparent(struct mptcp_sock *msk, int optname,
sockptr_t optval, unsigned int optlen)
{
@ -698,11 +758,29 @@ static int mptcp_setsockopt_v4(struct mptcp_sock *msk, int optname,
static int mptcp_setsockopt_sol_tcp(struct mptcp_sock *msk, int optname,
sockptr_t optval, unsigned int optlen)
{
struct sock *sk = (void *)msk;
int ret, val;
switch (optname) {
case TCP_INQ:
ret = mptcp_get_int_option(msk, optval, optlen, &val);
if (ret)
return ret;
if (val < 0 || val > 1)
return -EINVAL;
lock_sock(sk);
msk->recvmsg_inq = !!val;
release_sock(sk);
return 0;
case TCP_ULP:
return -EOPNOTSUPP;
case TCP_CONGESTION:
return mptcp_setsockopt_sol_tcp_congestion(msk, optval, optlen);
case TCP_CORK:
return mptcp_setsockopt_sol_tcp_cork(msk, optval, optlen);
case TCP_NODELAY:
return mptcp_setsockopt_sol_tcp_nodelay(msk, optval, optlen);
}
return -EOPNOTSUPP;
@ -1032,6 +1110,35 @@ static int mptcp_getsockopt_subflow_addrs(struct mptcp_sock *msk, char __user *o
return 0;
}
/* Copy the integer option value @val to user space for getsockopt().
 *
 * When the user buffer is shorter than an int (but not empty) and @val fits
 * in one byte, a single byte is returned; otherwise up to sizeof(int) bytes
 * of @val are copied.  The number of bytes written is stored back through
 * @optlen.
 *
 * Returns 0 on success, -EINVAL for a negative user length and -EFAULT when
 * a user-space access fails.
 */
static int mptcp_put_int_option(struct mptcp_sock *msk, char __user *optval,
				int __user *optlen, int val)
{
	unsigned char byte_val;
	const void *src;
	int len;

	if (get_user(len, optlen))
		return -EFAULT;
	if (len < 0)
		return -EINVAL;

	if (len < sizeof(int) && len > 0 && val >= 0 && val <= 255) {
		/* short buffer, small value: hand back one byte */
		byte_val = (unsigned char)val;
		src = &byte_val;
		len = 1;
	} else {
		src = &val;
		len = min_t(unsigned int, len, sizeof(int));
	}

	if (put_user(len, optlen))
		return -EFAULT;
	if (copy_to_user(optval, src, len))
		return -EFAULT;

	return 0;
}
static int mptcp_getsockopt_sol_tcp(struct mptcp_sock *msk, int optname,
char __user *optval, int __user *optlen)
{
@ -1042,10 +1149,29 @@ static int mptcp_getsockopt_sol_tcp(struct mptcp_sock *msk, int optname,
case TCP_CC_INFO:
return mptcp_getsockopt_first_sf_only(msk, SOL_TCP, optname,
optval, optlen);
case TCP_INQ:
return mptcp_put_int_option(msk, optval, optlen, msk->recvmsg_inq);
case TCP_CORK:
return mptcp_put_int_option(msk, optval, optlen, msk->cork);
case TCP_NODELAY:
return mptcp_put_int_option(msk, optval, optlen, msk->nodelay);
}
return -EOPNOTSUPP;
}
/* getsockopt() handler for SOL_IP options on an MPTCP socket.
 *
 * Only IP_TOS is handled: it reports the tos value stored on the MPTCP
 * socket itself.  Every other option yields -EOPNOTSUPP.
 */
static int mptcp_getsockopt_v4(struct mptcp_sock *msk, int optname,
			       char __user *optval, int __user *optlen)
{
	struct sock *sk = (void *)msk;

	if (optname == IP_TOS)
		return mptcp_put_int_option(msk, optval, optlen, inet_sk(sk)->tos);

	return -EOPNOTSUPP;
}
static int mptcp_getsockopt_sol_mptcp(struct mptcp_sock *msk, int optname,
char __user *optval, int __user *optlen)
{
@ -1081,6 +1207,8 @@ int mptcp_getsockopt(struct sock *sk, int level, int optname,
if (ssk)
return tcp_getsockopt(ssk, level, optname, optval, option);
if (level == SOL_IP)
return mptcp_getsockopt_v4(msk, optname, optval, option);
if (level == SOL_TCP)
return mptcp_getsockopt_sol_tcp(msk, optname, optval, option);
if (level == SOL_MPTCP)
@ -1129,6 +1257,8 @@ static void sync_socket_options(struct mptcp_sock *msk, struct sock *ssk)
if (inet_csk(sk)->icsk_ca_ops != inet_csk(ssk)->icsk_ca_ops)
tcp_set_congestion_control(ssk, msk->ca_name, false, true);
__tcp_sock_set_cork(ssk, !!msk->cork);
__tcp_sock_set_nodelay(ssk, !!msk->nodelay);
inet_sk(ssk)->transparent = inet_sk(sk)->transparent;
inet_sk(ssk)->freebind = inet_sk(sk)->freebind;

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: GPL-2.0-only
mptcp_connect
mptcp_inq
mptcp_sockopt
pm_nl_ctl
*.pcap

View File

@ -8,7 +8,7 @@ CFLAGS = -Wall -Wl,--no-as-needed -O2 -g -I$(top_srcdir)/usr/include
TEST_PROGS := mptcp_connect.sh pm_netlink.sh mptcp_join.sh diag.sh \
simult_flows.sh mptcp_sockopt.sh
TEST_GEN_FILES = mptcp_connect pm_nl_ctl mptcp_sockopt
TEST_GEN_FILES = mptcp_connect pm_nl_ctl mptcp_sockopt mptcp_inq
TEST_FILES := settings

View File

@ -73,12 +73,20 @@ static uint32_t cfg_mark;
/* Ancillary-data (cmsg) checks selected through the cmsg type option. */
struct cfg_cmsg_types {
/* when clear, receiving any cmsg data is reported as an error */
unsigned int cmsg_enabled:1;
/* enable SO_TIMESTAMPNS_NEW and require a timestamp cmsg on receive */
unsigned int timestampns:1;
/* enable TCP_INQ and require a TCP_CM_INQ cmsg on receive */
unsigned int tcp_inq:1;
};
/* Extra socket options requested for the test sockets. */
struct cfg_sockopt_types {
/* NOTE(review): presumably requests the IP transparent option; the code
 * consuming this flag is outside this excerpt -- confirm
 */
unsigned int transparent:1;
};
/* State carried between recvmsg() calls while checking TCP_INQ hints. */
struct tcp_inq_state {
/* value of the most recently received TCP_CM_INQ cmsg */
unsigned int last;
/* set once a short read left exactly one hinted byte unaccounted for:
 * the next recvmsg() is then required to return EOF
 */
bool expect_eof;
};
static struct tcp_inq_state tcp_inq;
static struct cfg_cmsg_types cfg_cmsg_types;
static struct cfg_sockopt_types cfg_sockopt_types;
@ -389,7 +397,9 @@ static size_t do_write(const int fd, char *buf, const size_t len)
static void process_cmsg(struct msghdr *msgh)
{
struct __kernel_timespec ts;
bool inq_found = false;
bool ts_found = false;
unsigned int inq = 0;
struct cmsghdr *cmsg;
for (cmsg = CMSG_FIRSTHDR(msgh); cmsg ; cmsg = CMSG_NXTHDR(msgh, cmsg)) {
@ -398,12 +408,27 @@ static void process_cmsg(struct msghdr *msgh)
ts_found = true;
continue;
}
if (cmsg->cmsg_level == IPPROTO_TCP && cmsg->cmsg_type == TCP_CM_INQ) {
memcpy(&inq, CMSG_DATA(cmsg), sizeof(inq));
inq_found = true;
continue;
}
}
if (cfg_cmsg_types.timestampns) {
if (!ts_found)
xerror("TIMESTAMPNS not present\n");
}
if (cfg_cmsg_types.tcp_inq) {
if (!inq_found)
xerror("TCP_INQ not present\n");
if (inq > 1024)
xerror("tcp_inq %u is larger than one kbyte\n", inq);
tcp_inq.last = inq;
}
}
static ssize_t do_recvmsg_cmsg(const int fd, char *buf, const size_t len)
@ -420,10 +445,23 @@ static ssize_t do_recvmsg_cmsg(const int fd, char *buf, const size_t len)
.msg_controllen = sizeof(msg_buf),
};
int flags = 0;
unsigned int last_hint = tcp_inq.last;
int ret = recvmsg(fd, &msg, flags);
if (ret <= 0)
if (ret <= 0) {
if (ret == 0 && tcp_inq.expect_eof)
return ret;
if (ret == 0 && cfg_cmsg_types.tcp_inq)
if (last_hint != 1 && last_hint != 0)
xerror("EOF but last tcp_inq hint was %u\n", last_hint);
return ret;
}
if (tcp_inq.expect_eof)
xerror("expected EOF, last_hint %u, now %u\n",
last_hint, tcp_inq.last);
if (msg.msg_controllen && !cfg_cmsg_types.cmsg_enabled)
xerror("got %lu bytes of cmsg data, expected 0\n",
@ -435,6 +473,19 @@ static ssize_t do_recvmsg_cmsg(const int fd, char *buf, const size_t len)
if (msg.msg_controllen)
process_cmsg(&msg);
if (cfg_cmsg_types.tcp_inq) {
if ((size_t)ret < len && last_hint > (unsigned int)ret) {
if (ret + 1 != (int)last_hint) {
int next = read(fd, msg_buf, sizeof(msg_buf));
xerror("read %u of %u, last_hint was %u tcp_inq hint now %u next_read returned %d/%m\n",
ret, (unsigned int)len, last_hint, tcp_inq.last, next);
} else {
tcp_inq.expect_eof = true;
}
}
}
return ret;
}
@ -944,6 +995,8 @@ static void apply_cmsg_types(int fd, const struct cfg_cmsg_types *cmsg)
if (cmsg->timestampns)
xsetsockopt(fd, SOL_SOCKET, SO_TIMESTAMPNS_NEW, &on, sizeof(on));
if (cmsg->tcp_inq)
xsetsockopt(fd, IPPROTO_TCP, TCP_INQ, &on, sizeof(on));
}
static void parse_cmsg_types(const char *type)
@ -965,6 +1018,11 @@ static void parse_cmsg_types(const char *type)
return;
}
if (strncmp(type, "TCPINQ", len) == 0) {
cfg_cmsg_types.tcp_inq = 1;
return;
}
fprintf(stderr, "Unrecognized cmsg option %s\n", type);
exit(1);
}

View File

@ -0,0 +1,603 @@
// SPDX-License-Identifier: GPL-2.0
#define _GNU_SOURCE
#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <string.h>
#include <stdarg.h>
#include <stdbool.h>
#include <stdint.h>
#include <inttypes.h>
#include <stdio.h>
#include <stdlib.h>
#include <strings.h>
#include <unistd.h>
#include <time.h>
#include <sys/ioctl.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <netdb.h>
#include <netinet/in.h>
#include <linux/tcp.h>
#include <linux/sockios.h>
#ifndef IPPROTO_MPTCP
#define IPPROTO_MPTCP 262
#endif
#ifndef SOL_MPTCP
#define SOL_MPTCP 284
#endif
static int pf = AF_INET;
static int proto_tx = IPPROTO_MPTCP;
static int proto_rx = IPPROTO_MPTCP;
/* Report @msg together with the current errno text via perror(), then
 * terminate the process with exit status 1.
 */
static void die_perror(const char *msg)
{
perror(msg);
exit(1);
}
/* Print the command-line synopsis on stderr and terminate the process with
 * exit status @r.
 */
static void die_usage(int r)
{
fprintf(stderr, "Usage: mptcp_inq [-6] [ -t tcp|mptcp ] [ -r tcp|mptcp]\n");
exit(r);
}
/* printf-style fatal error: format the message to stderr, append a newline
 * and terminate the process with exit status 1.
 */
static void xerror(const char *fmt, ...)
{
	va_list args;

	va_start(args, fmt);
	vfprintf(stderr, fmt, args);
	va_end(args);

	fputs("\n", stderr);
	exit(1);
}
/* Translate a getaddrinfo() error code into a message: EAI_SYSTEM defers to
 * the current errno, every other code goes through gai_strerror().
 */
static const char *getxinfo_strerr(int err)
{
	return err == EAI_SYSTEM ? strerror(errno) : gai_strerror(err);
}
/* getaddrinfo() wrapper that never returns on failure: the error is printed
 * on stderr and the process exits with status 1.
 */
static void xgetaddrinfo(const char *node, const char *service,
			 const struct addrinfo *hints,
			 struct addrinfo **res)
{
	int rc = getaddrinfo(node, service, hints, res);

	if (rc == 0)
		return;

	fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
		node ? node : "", service ? service : "",
		getxinfo_strerr(rc));
	exit(1);
}
/* Create a listening stream socket bound to @listenaddr:@port.
 *
 * @listenaddr must be a numeric address (AI_NUMERICHOST).  The address
 * family is taken from the file-level 'pf' and the socket protocol from
 * 'proto_rx' rather than from the lookup result.
 *
 * Returns the listening descriptor; any unrecoverable failure terminates
 * the process.
 */
static int sock_listen_mptcp(const char * const listenaddr,
			     const char * const port)
{
	struct addrinfo hints = {
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
		.ai_flags = AI_PASSIVE | AI_NUMERICHOST
	};
	struct addrinfo *a, *addr;
	/* start out invalid so an empty result list is caught below */
	int sock = -1;
	int one = 1;

	hints.ai_family = pf;

	xgetaddrinfo(listenaddr, port, &hints, &addr);

	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, proto_rx);
		if (sock < 0)
			continue;

		/* best effort: a failure here is reported but not fatal */
		if (-1 == setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
				     sizeof(one)))
			perror("setsockopt");

		if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
			break; /* success */

		perror("bind");
		close(sock);
		sock = -1;
	}

	freeaddrinfo(addr);

	if (sock < 0)
		xerror("could not create listen socket");

	if (listen(sock, 20))
		die_perror("listen");

	return sock;
}
/* Open a stream socket using protocol @proto and connect it to
 * @remoteaddr:@port, with the address family taken from the file-level 'pf'.
 *
 * Returns the connected descriptor.  A failing connect() or the inability
 * to create any socket terminates the process.
 */
static int sock_connect_mptcp(const char * const remoteaddr,
			      const char * const port, int proto)
{
	struct addrinfo hints = {
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
	};
	struct addrinfo *cur, *res;
	int fd = -1;

	hints.ai_family = pf;
	xgetaddrinfo(remoteaddr, port, &hints, &res);

	for (cur = res; cur != NULL; cur = cur->ai_next) {
		fd = socket(cur->ai_family, cur->ai_socktype, proto);
		if (fd < 0)
			continue;

		/* the first socket we manage to create must connect */
		if (connect(fd, cur->ai_addr, cur->ai_addrlen) != 0)
			die_perror("connect");

		break; /* success */
	}

	if (fd < 0)
		xerror("could not create connect socket");

	freeaddrinfo(res);
	return fd;
}
/* Map a case-insensitive protocol name ("tcp" or "mptcp") to its IPPROTO_*
 * number.  Any other name prints the usage text and exits with status 1.
 */
static int protostr_to_num(const char *s)
{
	static const struct {
		const char *name;
		int proto;
	} known[] = {
		{ "tcp", IPPROTO_TCP },
		{ "mptcp", IPPROTO_MPTCP },
	};
	size_t idx;

	for (idx = 0; idx < sizeof(known) / sizeof(known[0]); idx++) {
		if (strcasecmp(s, known[idx].name) == 0)
			return known[idx].proto;
	}

	die_usage(1);
	return 0;
}
/* Parse the command line:
 *   -h        print usage and exit successfully
 *   -6        use AF_INET6 instead of AF_INET
 *   -t PROTO  protocol of the connecting (sending) side
 *   -r PROTO  protocol of the listening (receiving) side
 * Anything else prints the usage text and exits with status 1.
 */
static void parse_opts(int argc, char **argv)
{
	int opt;

	for (opt = getopt(argc, argv, "h6t:r:"); opt != -1;
	     opt = getopt(argc, argv, "h6t:r:")) {
		if (opt == 'h')
			die_usage(0);
		else if (opt == '6')
			pf = AF_INET6;
		else if (opt == 't')
			proto_tx = protostr_to_num(optarg);
		else if (opt == 'r')
			proto_rx = protostr_to_num(optarg);
		else
			die_usage(1);
	}
}
/* Poll the socket's output queue once per millisecond, for at most @timeout
 * rounds, until TIOCOUTQ reports that nothing is queued any more.
 *
 * On every round the queued amount must not exceed @total and the
 * SIOCOUTQNSD value must not exceed the TIOCOUTQ one.  A queue that never
 * drains is a fatal error.
 */
static void wait_for_ack(int fd, int timeout, size_t total)
{
	const struct timespec one_ms = {
		.tv_sec = 0,
		.tv_nsec = 1 * 1000 * 1000ul, /* 1ms */
	};
	int round = 0;

	while (round < timeout) {
		int unsent, unacked = -1;

		if (ioctl(fd, TIOCOUTQ, &unacked) < 0)
			die_perror("TIOCOUTQ");

		if (ioctl(fd, SIOCOUTQNSD, &unsent) < 0)
			die_perror("SIOCOUTQNSD");

		if ((size_t)unacked > total)
			xerror("TIOCOUTQ %u, but only %zu expected\n", unacked, total);
		assert(unsent <= unacked);

		if (unacked == 0)
			return;

		/* wait for peer to ack rx of all data */
		nanosleep(&one_ms, NULL);
		round++;
	}

	xerror("still tx data queued after %u ms\n", timeout);
}
/* Sending side of the test, driven by the peer through @unixfd.
 *
 * Handshake on @unixfd (peer -> us / us -> peer):
 *   "xmit"  -> reply with the length of a small random message, then send
 *              that message on @fd;
 *   "huge"  -> reply with the size of a multi-megabyte transfer, wait for
 *              the small message to leave the output queue, then send it;
 *   "shut"  -> wait for the output queue to drain, send one final byte,
 *              close @fd and report "closed".
 *
 * Any protocol violation trips an assert or terminates the process.
 */
static void connect_one_server(int fd, int unixfd)
{
	size_t len, i, total, sent;
	char buf[4096], buf2[4096];
	ssize_t ret;

	len = rand() % (sizeof(buf) - 1);

	if (len < 128)
		len = 128;

	for (i = 0; i < len ; i++) {
		buf[i] = rand() % 26;
		buf[i] += 'A';
	}

	buf[i] = '\n';

	/* un-block server */
	ret = read(unixfd, buf2, 4);
	assert(ret == 4);

	assert(strncmp(buf2, "xmit", 4) == 0);

	ret = write(unixfd, &len, sizeof(len));
	assert(ret == (ssize_t)sizeof(len));

	ret = write(fd, buf, len);
	if (ret < 0)
		die_perror("write");

	if (ret != (ssize_t)len)
		xerror("short write");

	ret = read(unixfd, buf2, 4);
	/* a short or failed read would leave buf2 stale: check it like the
	 * other handshake steps do
	 */
	assert(ret == 4);
	assert(strncmp(buf2, "huge", 4) == 0);

	total = rand() % (16 * 1024 * 1024);
	total += (1 * 1024 * 1024);
	sent = total;

	ret = write(unixfd, &total, sizeof(total));
	assert(ret == (ssize_t)sizeof(total));

	wait_for_ack(fd, 5000, len);

	while (total > 0) {
		if (total > sizeof(buf))
			len = sizeof(buf);
		else
			len = total;

		ret = write(fd, buf, len);
		if (ret < 0)
			die_perror("write");
		total -= ret;

		/* we don't have to care about buf content, only
		 * number of total bytes sent
		 */
	}

	ret = read(unixfd, buf2, 4);
	assert(ret == 4);
	assert(strncmp(buf2, "shut", 4) == 0);

	wait_for_ack(fd, 5000, sent);

	ret = write(fd, buf, 1);
	assert(ret == 1);
	close(fd);
	ret = write(unixfd, "closed", 6);
	assert(ret == 6);

	close(unixfd);
}
/* Extract the TCP_CM_INQ value from the control data of @msgh into @inqv.
 * A message without such a cmsg is a fatal error.
 */
static void get_tcp_inq(struct msghdr *msgh, unsigned int *inqv)
{
	struct cmsghdr *hdr;

	for (hdr = CMSG_FIRSTHDR(msgh); hdr != NULL; hdr = CMSG_NXTHDR(msgh, hdr)) {
		bool is_inq = hdr->cmsg_level == IPPROTO_TCP &&
			      hdr->cmsg_type == TCP_CM_INQ;

		if (!is_inq)
			continue;

		memcpy(inqv, CMSG_DATA(hdr), sizeof(*inqv));
		return;
	}

	xerror("could not find TCP_CM_INQ cmsg type");
}
/* Receiving side of the test: validates FIONREAD and the TCP_CM_INQ cmsg
 * against what the peer announced over @unixfd.
 *
 * Phases:
 *  1. "xmit": the peer announces a small message; wait until FIONREAD
 *     reports exactly that many bytes, read one byte and expect the hint
 *     to be (announced - 1), then read the rest and expect a hint of 0.
 *  2. "huge": the peer announces a multi-megabyte transfer; on each read
 *     the hint must never exceed what is still outstanding.
 *  3. "shut": once the peer reports "closed", one last byte plus EOF are
 *     expected, both accompanied by a hint of 1.
 *
 * Any mismatch trips an assert or terminates the process.
 */
static void process_one_client(int fd, int unixfd)
{
	unsigned int tcp_inq;
	size_t expect_len;
	char msg_buf[4096];
	char buf[4096];
	char tmp[16];
	struct iovec iov = {
		.iov_base = buf,
		.iov_len = 1,
	};
	struct msghdr msg = {
		.msg_iov = &iov,
		.msg_iovlen = 1,
		.msg_control = msg_buf,
		.msg_controllen = sizeof(msg_buf),
	};
	ssize_t ret, tot;

	ret = write(unixfd, "xmit", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	if (expect_len > sizeof(buf))
		xerror("expect len %zu exceeds buffer size", expect_len);

	for (;;) {
		struct timespec req;
		unsigned int queued;

		ret = ioctl(fd, FIONREAD, &queued);
		if (ret < 0)
			die_perror("FIONREAD");
		if (queued > expect_len)
			xerror("FIONREAD returned %u, but only %zu expected\n",
			       queued, expect_len);
		if (queued == expect_len)
			break;

		req.tv_sec = 0;
		req.tv_nsec = 1000 * 1000ul;
		nanosleep(&req, NULL);
	}

	/* read one byte, expect cmsg to return expected - 1 */
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	if (msg.msg_controllen == 0)
		xerror("msg_controllen is 0");

	get_tcp_inq(&msg, &tcp_inq);

	assert((size_t)tcp_inq == (expect_len - 1));

	iov.iov_len = sizeof(buf);
	/* recvmsg() overwrites msg_controllen with the length it produced:
	 * hand the full control buffer back before every further call
	 */
	msg.msg_controllen = sizeof(msg_buf);
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	/* should have gotten exact remainder of all pending data */
	assert(ret == (ssize_t)tcp_inq);

	/* should be 0, all drained */
	get_tcp_inq(&msg, &tcp_inq);
	assert(tcp_inq == 0);

	/* request a large swath of data. */
	ret = write(unixfd, "huge", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	/* peer should send us a few mb of data */
	if (expect_len <= sizeof(buf))
		xerror("expect len %zu too small\n", expect_len);

	tot = 0;
	do {
		iov.iov_len = sizeof(buf);
		msg.msg_controllen = sizeof(msg_buf);
		ret = recvmsg(fd, &msg, 0);
		if (ret < 0)
			die_perror("recvmsg");

		tot += ret;

		get_tcp_inq(&msg, &tcp_inq);

		/* conversions match the argument types: the remaining byte
		 * count is a size_t, not an int
		 */
		if (tcp_inq > expect_len - tot)
			xerror("inq %u, remaining %zu total_len %zu\n",
			       tcp_inq, expect_len - (size_t)tot, expect_len);

		assert(tcp_inq <= expect_len - tot);
	} while ((size_t)tot < expect_len);

	ret = write(unixfd, "shut", 4);
	assert(ret == 4);

	/* wait for hangup. Should have received one more byte of data. */
	ret = read(unixfd, tmp, sizeof(tmp));
	assert(ret == 6);
	assert(strncmp(tmp, "closed", 6) == 0);

	sleep(1);

	iov.iov_len = 1;
	msg.msg_controllen = sizeof(msg_buf);
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");
	assert(ret == 1);

	get_tcp_inq(&msg, &tcp_inq);

	/* tcp_inq should be 1 due to received fin. */
	assert(tcp_inq == 1);

	iov.iov_len = 1;
	msg.msg_controllen = sizeof(msg_buf);
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	/* expect EOF */
	assert(ret == 0);
	get_tcp_inq(&msg, &tcp_inq);
	assert(tcp_inq == 1);

	close(fd);
}
/* accept() wrapper: returns the new connection's descriptor, or terminates
 * the process when accept() fails.  The peer address is not requested.
 */
static int xaccept(int s)
{
	int conn = accept(s, NULL, NULL);

	if (conn >= 0)
		return conn;

	die_perror("accept");
	return conn;
}
/* Server child: listen on the loopback address of family 'pf', port 15432,
 * tell the parent through @unixfd that the socket is ready ("conn"), accept
 * one connection, enable TCP_INQ on it and run the receive-side checks.
 *
 * A 15 second alarm bounds the run.  Returns 0.
 */
static int server(int unixfd)
{
	int listen_fd = -1, conn_fd, enable = 1;
	int written;

	if (pf == AF_INET)
		listen_fd = sock_listen_mptcp("127.0.0.1", "15432");
	else if (pf == AF_INET6)
		listen_fd = sock_listen_mptcp("::1", "15432");
	else
		xerror("Unknown pf %d\n", pf);

	written = write(unixfd, "conn", 4);
	assert(written == 4);

	alarm(15);
	conn_fd = xaccept(listen_fd);

	if (setsockopt(conn_fd, IPPROTO_TCP, TCP_INQ, &enable, sizeof(enable)) == -1)
		die_perror("setsockopt");

	process_one_client(conn_fd, unixfd);

	return 0;
}
/* Client child: connect to the loopback address of family 'pf', port 15432,
 * using protocol 'proto_tx', then run the send-side script.
 *
 * A 15 second alarm bounds the run.  Returns 0.
 */
static int client(int unixfd)
{
	int conn_fd = -1;

	alarm(15);

	if (pf == AF_INET)
		conn_fd = sock_connect_mptcp("127.0.0.1", "15432", proto_tx);
	else if (pf == AF_INET6)
		conn_fd = sock_connect_mptcp("::1", "15432", proto_tx);
	else
		xerror("Unknown pf %d\n", pf);

	connect_one_server(conn_fd, unixfd);

	return 0;
}
/* Seed rand() from /dev/urandom.
 *
 * When the device cannot be opened or does not deliver a full seed, fall
 * back to a seed derived from the current time and the process id, so the
 * generator is never seeded from an uninitialized value and the two forked
 * children still end up with different sequences.
 */
static void init_rng(void)
{
	unsigned int seed = (unsigned int)time(NULL) ^ (unsigned int)getpid();
	int fd = open("/dev/urandom", O_RDONLY);

	/* 0 is a valid descriptor: only negative values signal failure */
	if (fd >= 0) {
		unsigned int rnd;

		if (read(fd, &rnd, sizeof(rnd)) == (ssize_t)sizeof(rnd))
			seed = rnd;
		close(fd);
	}

	srand(seed);
}
/* fork() wrapper: a failed fork terminates the process; the child re-seeds
 * its random number generator before returning.  Returns fork()'s result.
 */
static pid_t xfork(void)
{
	pid_t pid = fork();

	if (pid == 0) {
		init_rng();
		return pid;
	}

	if (pid < 0)
		die_perror("fork");

	return pid;
}
/* Turn a waitpid() status for the child called @what into an exit code.
 *
 * A normal exit yields the child's exit status, logging it when non-zero.
 * A child killed or stopped by a signal is a fatal error.  Any other status
 * yields 111.
 */
static int rcheck(int wstatus, const char *what)
{
	if (WIFEXITED(wstatus)) {
		int code = WEXITSTATUS(wstatus);

		if (code != 0)
			fprintf(stderr, "%s exited, status=%d\n", what, code);
		return code;
	}

	if (WIFSIGNALED(wstatus))
		xerror("%s killed by signal %d\n", what, WTERMSIG(wstatus));
	else if (WIFSTOPPED(wstatus))
		xerror("%s stopped by signal %d\n", what, WSTOPSIG(wstatus));

	return 111;
}
/* Test driver: fork a server and a client child connected by a datagram
 * socketpair used for synchronisation, wait for both, and exit with the
 * server's status if it is non-zero, else with the client's.
 */
int main(int argc, char *argv[])
{
	int e1, e2, wstatus;
	pid_t s, c, ret;
	int unixfds[2];
	char ready[4];
	ssize_t n;

	parse_opts(argc, argv);

	e1 = socketpair(AF_UNIX, SOCK_DGRAM, 0, unixfds);
	if (e1 < 0)
		die_perror("socketpair");

	s = xfork();
	if (s == 0)
		return server(unixfds[1]);

	close(unixfds[1]);

	/* wait until server bound a socket */
	n = read(unixfds[0], ready, sizeof(ready));
	assert(n == (ssize_t)sizeof(ready));

	c = xfork();
	if (c == 0)
		return client(unixfds[0]);

	close(unixfds[0]);

	ret = waitpid(s, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e1 = rcheck(wstatus, "server");
	ret = waitpid(c, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e2 = rcheck(wstatus, "client");

	return e1 ? e1 : e2;
}

View File

@ -4,6 +4,7 @@
#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <string.h>
#include <stdarg.h>
@ -13,6 +14,7 @@
#include <stdio.h>
#include <stdlib.h>
#include <strings.h>
#include <time.h>
#include <unistd.h>
#include <sys/socket.h>
@ -594,6 +596,44 @@ static int server(int pipefd)
return 0;
}
/* Exercise IP_TOS set/get on @fd:
 *  - a random TOS value (low two bits clear) set with a one-byte option
 *    must read back unchanged, with a reported length of 1;
 *  - a zero-length getsockopt() must succeed and leave the length at 0;
 *  - a length of -1 must fail with EINVAL and leave the length untouched.
 * Any deviation terminates the process.
 */
static void test_ip_tos_sockopt(int fd)
{
	uint8_t tos_in, tos_out;
	socklen_t s;
	int r;

	tos_in = rand() & 0xfc;
	r = setsockopt(fd, SOL_IP, IP_TOS, &tos_in, sizeof(tos_in));
	if (r != 0)
		die_perror("setsockopt IP_TOS");

	tos_out = 0;
	s = sizeof(tos_out);
	r = getsockopt(fd, SOL_IP, IP_TOS, &tos_out, &s);
	if (r != 0)
		die_perror("getsockopt IP_TOS");

	if (tos_in != tos_out)
		xerror("tos %x != %x socklen_t %d\n", tos_in, tos_out, s);

	if (s != 1)
		xerror("tos should be 1 byte");

	s = 0;
	r = getsockopt(fd, SOL_IP, IP_TOS, &tos_out, &s);
	if (r != 0)
		die_perror("getsockopt IP_TOS 0");
	if (s != 0)
		xerror("expect socklen_t == 0");

	s = -1;
	r = getsockopt(fd, SOL_IP, IP_TOS, &tos_out, &s);
	/* the call must fail AND the failure must be EINVAL: anything else,
	 * including an unexpected success, is an error
	 */
	if (r != -1 || errno != EINVAL)
		die_perror("getsockopt IP_TOS did not indicate -EINVAL");
	if (s != (socklen_t)-1)
		xerror("expect socklen_t == -1");
}
static int client(int pipefd)
{
int fd = -1;
@ -611,6 +651,8 @@ static int client(int pipefd)
xerror("Unknown pf %d\n", pf);
}
test_ip_tos_sockopt(fd);
connect_one_server(fd, pipefd);
return 0;
@ -642,6 +684,25 @@ static int rcheck(int wstatus, const char *what)
return 111;
}
/* Seed rand(): from /dev/urandom when it can be opened, otherwise from the
 * current time.
 */
static void init_rng(void)
{
	unsigned int seed;
	ssize_t got;
	int fd;

	fd = open("/dev/urandom", O_RDONLY);
	if (fd < 0) {
		srand(time(NULL));
		return;
	}

	/* can't fail */
	got = read(fd, &seed, sizeof(seed));
	assert(got == sizeof(seed));

	close(fd);
	srand(seed);
}
int main(int argc, char *argv[])
{
int e1, e2, wstatus;
@ -650,6 +711,8 @@ int main(int argc, char *argv[])
parse_opts(argc, argv);
init_rng();
e1 = pipe(pipefds);
if (e1 < 0)
die_perror("pipe");

View File

@ -178,7 +178,7 @@ do_transfer()
timeout ${timeout_test} \
ip netns exec ${listener_ns} \
$mptcp_connect -t ${timeout_poll} -l -M 1 -p $port -s ${srv_proto} -c TIMESTAMPNS \
$mptcp_connect -t ${timeout_poll} -l -M 1 -p $port -s ${srv_proto} -c TIMESTAMPNS,TCPINQ \
${local_addr} < "$sin" > "$sout" &
spid=$!
@ -186,7 +186,7 @@ do_transfer()
timeout ${timeout_test} \
ip netns exec ${connector_ns} \
$mptcp_connect -t ${timeout_poll} -M 2 -p $port -s ${cl_proto} -c TIMESTAMPNS \
$mptcp_connect -t ${timeout_poll} -M 2 -p $port -s ${cl_proto} -c TIMESTAMPNS,TCPINQ \
$connect_addr < "$cin" > "$cout" &
cpid=$!
@ -279,6 +279,45 @@ run_tests()
fi
}
# Run ./mptcp_inq inside the $ns1 namespace with the given arguments.
# On failure the exit status is also recorded in the script-wide 'ret'.
# Returns the exit status of mptcp_inq.
do_tcpinq_test()
{
	ip netns exec "$ns1" ./mptcp_inq "$@"
	lret=$?

	if [ $lret -eq 0 ];then
		echo "PASS: TCP_INQ cmsg/ioctl $@"
		return $lret
	fi

	ret=$lret
	echo "FAIL: mptcp_inq $@" 1>&2
	return $lret
}
# Run the mptcp_inq test matrix in the $ns1 namespace:
#   - plain TCP on the sending side only (-t tcp), over IPv4 then IPv6
#   - plain TCP on the receiving side only (-r tcp), over IPv4 then IPv6
#   - plain TCP on both sides
# Stops at the first failure and returns that run's exit status.
do_tcpinq_tests()
{
local lret=0
# flush the namespace's IPv4 and IPv6 iptables rules before the runs
ip netns exec "$ns1" iptables -F
ip netns exec "$ns1" ip6tables -F
for args in "-t tcp" "-r tcp"; do
# $args is intentionally unquoted: it must split into option and value
do_tcpinq_test $args
lret=$?
if [ $lret -ne 0 ] ; then
return $lret
fi
do_tcpinq_test -6 $args
lret=$?
if [ $lret -ne 0 ] ; then
return $lret
fi
done
do_tcpinq_test -r tcp -t tcp
return $?
}
sin=$(mktemp)
sout=$(mktemp)
cin=$(mktemp)
@ -300,4 +339,5 @@ if [ $ret -eq 0 ];then
echo "PASS: SOL_MPTCP getsockopt has expected information"
fi
do_tcpinq_tests
exit $ret