Merge branch 'MCTP-tag-control-interface'

Jeremy Kerr says:

====================
MCTP tag control interface

This series implements a small interface for userspace-controlled
message tag allocation for the MCTP protocol. Rather than leaving the
kernel to allocate per-message tag values, userspace can explicitly
allocate (and release) message tags through two new ioctls:
SIOCMCTPALLOCTAG and SIOCMCTPDROPTAG.

In order to do this, we first introduce some minor changes to the tag
handling, including a couple of new tests for the route input paths.

As always, any comments/queries/etc are most welcome.

v2:
 - make mctp_lookup_prealloc_tag static
 - minor checkpatch formatting fixes
====================

Signed-off-by: David S. Miller <davem@davemloft.net>
This commit is contained in:
David S. Miller 2022-02-09 12:00:11 +00:00
commit b4f029f4f4
7 changed files with 489 additions and 68 deletions

View File

@ -212,6 +212,54 @@ remote address is already known, or the message does not require a reply.
Like the send calls, sockets will only receive responses to requests they have Like the send calls, sockets will only receive responses to requests they have
sent (TO=1) and may only respond (TO=0) to requests they have received. sent (TO=1) and may only respond (TO=0) to requests they have received.
``ioctl(SIOCMCTPALLOCTAG)`` and ``ioctl(SIOCMCTPDROPTAG)``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
These tags give applications more control over MCTP message tags, by allocating
(and dropping) tag values explicitly, rather than the kernel automatically
allocating a per-message tag at ``sendmsg()`` time.
In general, you will only need to use these ioctls if your MCTP protocol does
not fit the usual request/response model. For example, if you need to persist
tags across multiple requests, or a request may generate more than one response.
In these cases, the ioctls allow you to decouple the tag allocation (and
release) from individual message send and receive operations.
Both ioctls are passed a pointer to a ``struct mctp_ioc_tag_ctl``:
.. code-block:: C
struct mctp_ioc_tag_ctl {
mctp_eid_t peer_addr;
__u8 tag;
__u16 flags;
};
``SIOCMCTPALLOCTAG`` allocates a tag for a specific peer, which an application
can use in future ``sendmsg()`` calls. The application populates the
``peer_addr`` member with the remote EID. Other fields must be zero.
On return, the ``tag`` member will be populated with the allocated tag value.
The allocated tag will have the following tag bits set:
- ``MCTP_TAG_OWNER``: it only makes sense to allocate tags if you're the tag
owner
- ``MCTP_TAG_PREALLOC``: to indicate to ``sendmsg()`` that this is a
preallocated tag.
- ... and the actual tag value, within the least-significant three bits
(``MCTP_TAG_MASK``). Note that zero is a valid tag value.
The tag value should be used as-is for the ``smctp_tag`` member of ``struct
sockaddr_mctp``.
``SIOCMCTPDROPTAG`` releases a tag that has been previously allocated by a
``SIOCMCTPALLOCTAG`` ioctl. The ``peer_addr`` must be the same as used for the
allocation, and the ``tag`` value must match exactly the tag returned from the
allocation (including the ``MCTP_TAG_OWNER`` and ``MCTP_TAG_PREALLOC`` bits).
The ``flags`` field must be zero.
Kernel internals Kernel internals
================ ================

View File

@ -45,6 +45,11 @@ static inline bool mctp_address_ok(mctp_eid_t eid)
return eid >= 8 && eid < 255; return eid >= 8 && eid < 255;
} }
static inline bool mctp_address_matches(mctp_eid_t match, mctp_eid_t eid)
{
return match == eid || match == MCTP_ADDR_ANY;
}
static inline struct mctp_hdr *mctp_hdr(struct sk_buff *skb) static inline struct mctp_hdr *mctp_hdr(struct sk_buff *skb)
{ {
return (struct mctp_hdr *)skb_network_header(skb); return (struct mctp_hdr *)skb_network_header(skb);
@ -121,7 +126,7 @@ struct mctp_sock {
*/ */
struct mctp_sk_key { struct mctp_sk_key {
mctp_eid_t peer_addr; mctp_eid_t peer_addr;
mctp_eid_t local_addr; mctp_eid_t local_addr; /* MCTP_ADDR_ANY for local owned tags */
__u8 tag; /* incoming tag match; invert TO for local */ __u8 tag; /* incoming tag match; invert TO for local */
/* we hold a ref to sk when set */ /* we hold a ref to sk when set */
@ -158,6 +163,12 @@ struct mctp_sk_key {
*/ */
unsigned long dev_flow_state; unsigned long dev_flow_state;
struct mctp_dev *dev; struct mctp_dev *dev;
/* a tag allocated with SIOCMCTPALLOCTAG ioctl will not expire
* automatically on timeout or response, instead SIOCMCTPDROPTAG
* is used.
*/
bool manual_alloc;
}; };
struct mctp_skb_cb { struct mctp_skb_cb {
@ -234,6 +245,9 @@ int mctp_local_output(struct sock *sk, struct mctp_route *rt,
struct sk_buff *skb, mctp_eid_t daddr, u8 req_tag); struct sk_buff *skb, mctp_eid_t daddr, u8 req_tag);
void mctp_key_unref(struct mctp_sk_key *key); void mctp_key_unref(struct mctp_sk_key *key);
struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk,
mctp_eid_t daddr, mctp_eid_t saddr,
bool manual, u8 *tagp);
/* routing <--> device interface */ /* routing <--> device interface */
unsigned int mctp_default_net(struct net *net); unsigned int mctp_default_net(struct net *net);

View File

@ -15,6 +15,7 @@ enum {
MCTP_TRACE_KEY_REPLIED, MCTP_TRACE_KEY_REPLIED,
MCTP_TRACE_KEY_INVALIDATED, MCTP_TRACE_KEY_INVALIDATED,
MCTP_TRACE_KEY_CLOSED, MCTP_TRACE_KEY_CLOSED,
MCTP_TRACE_KEY_DROPPED,
}; };
#endif /* __TRACE_MCTP_ENUMS */ #endif /* __TRACE_MCTP_ENUMS */
@ -22,6 +23,7 @@ TRACE_DEFINE_ENUM(MCTP_TRACE_KEY_TIMEOUT);
TRACE_DEFINE_ENUM(MCTP_TRACE_KEY_REPLIED); TRACE_DEFINE_ENUM(MCTP_TRACE_KEY_REPLIED);
TRACE_DEFINE_ENUM(MCTP_TRACE_KEY_INVALIDATED); TRACE_DEFINE_ENUM(MCTP_TRACE_KEY_INVALIDATED);
TRACE_DEFINE_ENUM(MCTP_TRACE_KEY_CLOSED); TRACE_DEFINE_ENUM(MCTP_TRACE_KEY_CLOSED);
TRACE_DEFINE_ENUM(MCTP_TRACE_KEY_DROPPED);
TRACE_EVENT(mctp_key_acquire, TRACE_EVENT(mctp_key_acquire,
TP_PROTO(const struct mctp_sk_key *key), TP_PROTO(const struct mctp_sk_key *key),
@ -66,7 +68,8 @@ TRACE_EVENT(mctp_key_release,
{ MCTP_TRACE_KEY_TIMEOUT, "timeout" }, { MCTP_TRACE_KEY_TIMEOUT, "timeout" },
{ MCTP_TRACE_KEY_REPLIED, "replied" }, { MCTP_TRACE_KEY_REPLIED, "replied" },
{ MCTP_TRACE_KEY_INVALIDATED, "invalidated" }, { MCTP_TRACE_KEY_INVALIDATED, "invalidated" },
{ MCTP_TRACE_KEY_CLOSED, "closed" }) { MCTP_TRACE_KEY_CLOSED, "closed" },
{ MCTP_TRACE_KEY_DROPPED, "dropped" })
) )
); );

View File

@ -44,7 +44,25 @@ struct sockaddr_mctp_ext {
#define MCTP_TAG_MASK 0x07 #define MCTP_TAG_MASK 0x07
#define MCTP_TAG_OWNER 0x08 #define MCTP_TAG_OWNER 0x08
#define MCTP_TAG_PREALLOC 0x10
#define MCTP_OPT_ADDR_EXT 1 #define MCTP_OPT_ADDR_EXT 1
#define SIOCMCTPALLOCTAG (SIOCPROTOPRIVATE + 0)
#define SIOCMCTPDROPTAG (SIOCPROTOPRIVATE + 1)
struct mctp_ioc_tag_ctl {
mctp_eid_t peer_addr;
/* For SIOCMCTPALLOCTAG: must be passed as zero, kernel will
* populate with the allocated tag value. Returned tag value will
* always have TO and PREALLOC set.
*
* For SIOCMCTPDROPTAG: userspace provides tag value to drop, from
* a prior SIOCMCTPALLOCTAG call (and so must have TO and PREALLOC set).
*/
__u8 tag;
__u16 flags;
};
#endif /* __UAPI_MCTP_H */ #endif /* __UAPI_MCTP_H */

View File

@ -6,6 +6,7 @@
* Copyright (c) 2021 Google * Copyright (c) 2021 Google
*/ */
#include <linux/compat.h>
#include <linux/if_arp.h> #include <linux/if_arp.h>
#include <linux/net.h> #include <linux/net.h>
#include <linux/mctp.h> #include <linux/mctp.h>
@ -21,6 +22,8 @@
/* socket implementation */ /* socket implementation */
static void mctp_sk_expire_keys(struct timer_list *timer);
static int mctp_release(struct socket *sock) static int mctp_release(struct socket *sock)
{ {
struct sock *sk = sock->sk; struct sock *sk = sock->sk;
@ -99,13 +102,20 @@ static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
struct sk_buff *skb; struct sk_buff *skb;
if (addr) { if (addr) {
const u8 tagbits = MCTP_TAG_MASK | MCTP_TAG_OWNER |
MCTP_TAG_PREALLOC;
if (addrlen < sizeof(struct sockaddr_mctp)) if (addrlen < sizeof(struct sockaddr_mctp))
return -EINVAL; return -EINVAL;
if (addr->smctp_family != AF_MCTP) if (addr->smctp_family != AF_MCTP)
return -EINVAL; return -EINVAL;
if (!mctp_sockaddr_is_ok(addr)) if (!mctp_sockaddr_is_ok(addr))
return -EINVAL; return -EINVAL;
if (addr->smctp_tag & ~(MCTP_TAG_MASK | MCTP_TAG_OWNER)) if (addr->smctp_tag & ~tagbits)
return -EINVAL;
/* can't preallocate a non-owned tag */
if (addr->smctp_tag & MCTP_TAG_PREALLOC &&
!(addr->smctp_tag & MCTP_TAG_OWNER))
return -EINVAL; return -EINVAL;
} else { } else {
@ -248,6 +258,32 @@ out_free:
return rc; return rc;
} }
/* We're done with the key; invalidate, stop reassembly, and remove from lists.
*/
static void __mctp_key_remove(struct mctp_sk_key *key, struct net *net,
unsigned long flags, unsigned long reason)
__releases(&key->lock)
__must_hold(&net->mctp.keys_lock)
{
struct sk_buff *skb;
trace_mctp_key_release(key, reason);
skb = key->reasm_head;
key->reasm_head = NULL;
key->reasm_dead = true;
key->valid = false;
mctp_dev_release_key(key->dev, key);
spin_unlock_irqrestore(&key->lock, flags);
hlist_del(&key->hlist);
hlist_del(&key->sklist);
/* unref for the lists */
mctp_key_unref(key);
kfree_skb(skb);
}
static int mctp_setsockopt(struct socket *sock, int level, int optname, static int mctp_setsockopt(struct socket *sock, int level, int optname,
sockptr_t optval, unsigned int optlen) sockptr_t optval, unsigned int optlen)
{ {
@ -293,6 +329,115 @@ static int mctp_getsockopt(struct socket *sock, int level, int optname,
return -EINVAL; return -EINVAL;
} }
static int mctp_ioctl_alloctag(struct mctp_sock *msk, unsigned long arg)
{
struct net *net = sock_net(&msk->sk);
struct mctp_sk_key *key = NULL;
struct mctp_ioc_tag_ctl ctl;
unsigned long flags;
u8 tag;
if (copy_from_user(&ctl, (void __user *)arg, sizeof(ctl)))
return -EFAULT;
if (ctl.tag)
return -EINVAL;
if (ctl.flags)
return -EINVAL;
key = mctp_alloc_local_tag(msk, ctl.peer_addr, MCTP_ADDR_ANY,
true, &tag);
if (IS_ERR(key))
return PTR_ERR(key);
ctl.tag = tag | MCTP_TAG_OWNER | MCTP_TAG_PREALLOC;
if (copy_to_user((void __user *)arg, &ctl, sizeof(ctl))) {
spin_lock_irqsave(&key->lock, flags);
__mctp_key_remove(key, net, flags, MCTP_TRACE_KEY_DROPPED);
mctp_key_unref(key);
return -EFAULT;
}
mctp_key_unref(key);
return 0;
}
static int mctp_ioctl_droptag(struct mctp_sock *msk, unsigned long arg)
{
struct net *net = sock_net(&msk->sk);
struct mctp_ioc_tag_ctl ctl;
unsigned long flags, fl2;
struct mctp_sk_key *key;
struct hlist_node *tmp;
int rc;
u8 tag;
if (copy_from_user(&ctl, (void __user *)arg, sizeof(ctl)))
return -EFAULT;
if (ctl.flags)
return -EINVAL;
/* Must be a local tag, TO set, preallocated */
if ((ctl.tag & ~MCTP_TAG_MASK) != (MCTP_TAG_OWNER | MCTP_TAG_PREALLOC))
return -EINVAL;
tag = ctl.tag & MCTP_TAG_MASK;
rc = -EINVAL;
spin_lock_irqsave(&net->mctp.keys_lock, flags);
hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
/* we do an irqsave here, even though we know the irq state,
* so we have the flags to pass to __mctp_key_remove
*/
spin_lock_irqsave(&key->lock, fl2);
if (key->manual_alloc &&
ctl.peer_addr == key->peer_addr &&
tag == key->tag) {
__mctp_key_remove(key, net, fl2,
MCTP_TRACE_KEY_DROPPED);
rc = 0;
} else {
spin_unlock_irqrestore(&key->lock, fl2);
}
}
spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
return rc;
}
static int mctp_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg)
{
struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk);
switch (cmd) {
case SIOCMCTPALLOCTAG:
return mctp_ioctl_alloctag(msk, arg);
case SIOCMCTPDROPTAG:
return mctp_ioctl_droptag(msk, arg);
}
return -EINVAL;
}
#ifdef CONFIG_COMPAT
static int mctp_compat_ioctl(struct socket *sock, unsigned int cmd,
unsigned long arg)
{
void __user *argp = compat_ptr(arg);
switch (cmd) {
/* These have compatible ptr layouts */
case SIOCMCTPALLOCTAG:
case SIOCMCTPDROPTAG:
return mctp_ioctl(sock, cmd, (unsigned long)argp);
}
return -ENOIOCTLCMD;
}
#endif
static const struct proto_ops mctp_dgram_ops = { static const struct proto_ops mctp_dgram_ops = {
.family = PF_MCTP, .family = PF_MCTP,
.release = mctp_release, .release = mctp_release,
@ -302,7 +447,7 @@ static const struct proto_ops mctp_dgram_ops = {
.accept = sock_no_accept, .accept = sock_no_accept,
.getname = sock_no_getname, .getname = sock_no_getname,
.poll = datagram_poll, .poll = datagram_poll,
.ioctl = sock_no_ioctl, .ioctl = mctp_ioctl,
.gettstamp = sock_gettstamp, .gettstamp = sock_gettstamp,
.listen = sock_no_listen, .listen = sock_no_listen,
.shutdown = sock_no_shutdown, .shutdown = sock_no_shutdown,
@ -312,6 +457,9 @@ static const struct proto_ops mctp_dgram_ops = {
.recvmsg = mctp_recvmsg, .recvmsg = mctp_recvmsg,
.mmap = sock_no_mmap, .mmap = sock_no_mmap,
.sendpage = sock_no_sendpage, .sendpage = sock_no_sendpage,
#ifdef CONFIG_COMPAT
.compat_ioctl = mctp_compat_ioctl,
#endif
}; };
static void mctp_sk_expire_keys(struct timer_list *timer) static void mctp_sk_expire_keys(struct timer_list *timer)
@ -319,7 +467,7 @@ static void mctp_sk_expire_keys(struct timer_list *timer)
struct mctp_sock *msk = container_of(timer, struct mctp_sock, struct mctp_sock *msk = container_of(timer, struct mctp_sock,
key_expiry); key_expiry);
struct net *net = sock_net(&msk->sk); struct net *net = sock_net(&msk->sk);
unsigned long next_expiry, flags; unsigned long next_expiry, flags, fl2;
struct mctp_sk_key *key; struct mctp_sk_key *key;
struct hlist_node *tmp; struct hlist_node *tmp;
bool next_expiry_valid = false; bool next_expiry_valid = false;
@ -327,15 +475,16 @@ static void mctp_sk_expire_keys(struct timer_list *timer)
spin_lock_irqsave(&net->mctp.keys_lock, flags); spin_lock_irqsave(&net->mctp.keys_lock, flags);
hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) { hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
spin_lock(&key->lock); /* don't expire. manual_alloc is immutable, no locking
* required.
*/
if (key->manual_alloc)
continue;
spin_lock_irqsave(&key->lock, fl2);
if (!time_after_eq(key->expiry, jiffies)) { if (!time_after_eq(key->expiry, jiffies)) {
trace_mctp_key_release(key, MCTP_TRACE_KEY_TIMEOUT); __mctp_key_remove(key, net, fl2,
key->valid = false; MCTP_TRACE_KEY_TIMEOUT);
hlist_del_rcu(&key->hlist);
hlist_del_rcu(&key->sklist);
spin_unlock(&key->lock);
mctp_key_unref(key);
continue; continue;
} }
@ -346,7 +495,7 @@ static void mctp_sk_expire_keys(struct timer_list *timer)
next_expiry = key->expiry; next_expiry = key->expiry;
next_expiry_valid = true; next_expiry_valid = true;
} }
spin_unlock(&key->lock); spin_unlock_irqrestore(&key->lock, fl2);
} }
spin_unlock_irqrestore(&net->mctp.keys_lock, flags); spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
@ -387,9 +536,9 @@ static void mctp_sk_unhash(struct sock *sk)
{ {
struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
struct net *net = sock_net(sk); struct net *net = sock_net(sk);
unsigned long flags, fl2;
struct mctp_sk_key *key; struct mctp_sk_key *key;
struct hlist_node *tmp; struct hlist_node *tmp;
unsigned long flags;
/* remove from any type-based binds */ /* remove from any type-based binds */
mutex_lock(&net->mctp.bind_lock); mutex_lock(&net->mctp.bind_lock);
@ -399,20 +548,8 @@ static void mctp_sk_unhash(struct sock *sk)
/* remove tag allocations */ /* remove tag allocations */
spin_lock_irqsave(&net->mctp.keys_lock, flags); spin_lock_irqsave(&net->mctp.keys_lock, flags);
hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) { hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
hlist_del(&key->sklist); spin_lock_irqsave(&key->lock, fl2);
hlist_del(&key->hlist); __mctp_key_remove(key, net, fl2, MCTP_TRACE_KEY_CLOSED);
trace_mctp_key_release(key, MCTP_TRACE_KEY_CLOSED);
spin_lock(&key->lock);
kfree_skb(key->reasm_head);
key->reasm_head = NULL;
key->reasm_dead = true;
key->valid = false;
spin_unlock(&key->lock);
/* key is no longer on the lookup lists, unref */
mctp_key_unref(key);
} }
spin_unlock_irqrestore(&net->mctp.keys_lock, flags); spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
} }

View File

@ -64,8 +64,7 @@ static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb)
if (msk->bind_type != type) if (msk->bind_type != type)
continue; continue;
if (msk->bind_addr != MCTP_ADDR_ANY && if (!mctp_address_matches(msk->bind_addr, mh->dest))
msk->bind_addr != mh->dest)
continue; continue;
return msk; return msk;
@ -77,7 +76,7 @@ static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb)
static bool mctp_key_match(struct mctp_sk_key *key, mctp_eid_t local, static bool mctp_key_match(struct mctp_sk_key *key, mctp_eid_t local,
mctp_eid_t peer, u8 tag) mctp_eid_t peer, u8 tag)
{ {
if (key->local_addr != local) if (!mctp_address_matches(key->local_addr, local))
return false; return false;
if (key->peer_addr != peer) if (key->peer_addr != peer)
@ -204,29 +203,38 @@ static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk)
return rc; return rc;
} }
/* We're done with the key; unset valid and remove from lists. There may still /* Helper for mctp_route_input().
* be outstanding refs on the key though... * We're done with the key; unlock and unref the key.
* For the usual case of automatic expiry we remove the key from lists.
* In the case that manual allocation is set on a key we release the lock
* and local ref, reset reassembly, but don't remove from lists.
*/ */
static void __mctp_key_unlock_drop(struct mctp_sk_key *key, struct net *net, static void __mctp_key_done_in(struct mctp_sk_key *key, struct net *net,
unsigned long flags) unsigned long flags, unsigned long reason)
__releases(&key->lock) __releases(&key->lock)
{ {
struct sk_buff *skb; struct sk_buff *skb;
trace_mctp_key_release(key, reason);
skb = key->reasm_head; skb = key->reasm_head;
key->reasm_head = NULL; key->reasm_head = NULL;
if (!key->manual_alloc) {
key->reasm_dead = true; key->reasm_dead = true;
key->valid = false; key->valid = false;
mctp_dev_release_key(key->dev, key); mctp_dev_release_key(key->dev, key);
}
spin_unlock_irqrestore(&key->lock, flags); spin_unlock_irqrestore(&key->lock, flags);
if (!key->manual_alloc) {
spin_lock_irqsave(&net->mctp.keys_lock, flags); spin_lock_irqsave(&net->mctp.keys_lock, flags);
hlist_del(&key->hlist); hlist_del(&key->hlist);
hlist_del(&key->sklist); hlist_del(&key->sklist);
spin_unlock_irqrestore(&net->mctp.keys_lock, flags); spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
/* one unref for the lists */ /* unref for the lists */
mctp_key_unref(key); mctp_key_unref(key);
}
/* and one for the local reference */ /* and one for the local reference */
mctp_key_unref(key); mctp_key_unref(key);
@ -380,9 +388,8 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
/* we've hit a pending reassembly; not much we /* we've hit a pending reassembly; not much we
* can do but drop it * can do but drop it
*/ */
trace_mctp_key_release(key, __mctp_key_done_in(key, net, f,
MCTP_TRACE_KEY_REPLIED); MCTP_TRACE_KEY_REPLIED);
__mctp_key_unlock_drop(key, net, f);
key = NULL; key = NULL;
} }
rc = 0; rc = 0;
@ -424,9 +431,8 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
} else { } else {
if (key->reasm_head || key->reasm_dead) { if (key->reasm_head || key->reasm_dead) {
/* duplicate start? drop everything */ /* duplicate start? drop everything */
trace_mctp_key_release(key, __mctp_key_done_in(key, net, f,
MCTP_TRACE_KEY_INVALIDATED); MCTP_TRACE_KEY_INVALIDATED);
__mctp_key_unlock_drop(key, net, f);
rc = -EEXIST; rc = -EEXIST;
key = NULL; key = NULL;
} else { } else {
@ -449,10 +455,10 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
* the reassembly/response key * the reassembly/response key
*/ */
if (!rc && flags & MCTP_HDR_FLAG_EOM) { if (!rc && flags & MCTP_HDR_FLAG_EOM) {
msk = container_of(key->sk, struct mctp_sock, sk);
sock_queue_rcv_skb(key->sk, key->reasm_head); sock_queue_rcv_skb(key->sk, key->reasm_head);
key->reasm_head = NULL; key->reasm_head = NULL;
trace_mctp_key_release(key, MCTP_TRACE_KEY_REPLIED); __mctp_key_done_in(key, net, f, MCTP_TRACE_KEY_REPLIED);
__mctp_key_unlock_drop(key, net, f);
key = NULL; key = NULL;
} }
@ -580,9 +586,9 @@ static void mctp_reserve_tag(struct net *net, struct mctp_sk_key *key,
/* Allocate a locally-owned tag value for (saddr, daddr), and reserve /* Allocate a locally-owned tag value for (saddr, daddr), and reserve
* it for the socket msk * it for the socket msk
*/ */
static struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk, struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk,
mctp_eid_t saddr, mctp_eid_t daddr, mctp_eid_t saddr,
mctp_eid_t daddr, u8 *tagp) bool manual, u8 *tagp)
{ {
struct net *net = sock_net(&msk->sk); struct net *net = sock_net(&msk->sk);
struct netns_mctp *mns = &net->mctp; struct netns_mctp *mns = &net->mctp;
@ -616,9 +622,8 @@ static struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk,
if (tmp->tag & MCTP_HDR_FLAG_TO) if (tmp->tag & MCTP_HDR_FLAG_TO)
continue; continue;
if (!((tmp->peer_addr == daddr || if (!(mctp_address_matches(tmp->peer_addr, daddr) &&
tmp->peer_addr == MCTP_ADDR_ANY) && mctp_address_matches(tmp->local_addr, saddr)))
tmp->local_addr == saddr))
continue; continue;
spin_lock(&tmp->lock); spin_lock(&tmp->lock);
@ -638,6 +643,7 @@ static struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk,
mctp_reserve_tag(net, key, msk); mctp_reserve_tag(net, key, msk);
trace_mctp_key_acquire(key); trace_mctp_key_acquire(key);
key->manual_alloc = manual;
*tagp = key->tag; *tagp = key->tag;
} }
@ -651,6 +657,50 @@ static struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk,
return key; return key;
} }
static struct mctp_sk_key *mctp_lookup_prealloc_tag(struct mctp_sock *msk,
mctp_eid_t daddr,
u8 req_tag, u8 *tagp)
{
struct net *net = sock_net(&msk->sk);
struct netns_mctp *mns = &net->mctp;
struct mctp_sk_key *key, *tmp;
unsigned long flags;
req_tag &= ~(MCTP_TAG_PREALLOC | MCTP_TAG_OWNER);
key = NULL;
spin_lock_irqsave(&mns->keys_lock, flags);
hlist_for_each_entry(tmp, &mns->keys, hlist) {
if (tmp->tag != req_tag)
continue;
if (!mctp_address_matches(tmp->peer_addr, daddr))
continue;
if (!tmp->manual_alloc)
continue;
spin_lock(&tmp->lock);
if (tmp->valid) {
key = tmp;
refcount_inc(&key->refs);
spin_unlock(&tmp->lock);
break;
}
spin_unlock(&tmp->lock);
}
spin_unlock_irqrestore(&mns->keys_lock, flags);
if (!key)
return ERR_PTR(-ENOENT);
if (tagp)
*tagp = key->tag;
return key;
}
/* routing lookups */ /* routing lookups */
static bool mctp_rt_match_eid(struct mctp_route *rt, static bool mctp_rt_match_eid(struct mctp_route *rt,
unsigned int net, mctp_eid_t eid) unsigned int net, mctp_eid_t eid)
@ -845,8 +895,14 @@ int mctp_local_output(struct sock *sk, struct mctp_route *rt,
if (rc) if (rc)
goto out_release; goto out_release;
if (req_tag & MCTP_HDR_FLAG_TO) { if (req_tag & MCTP_TAG_OWNER) {
key = mctp_alloc_local_tag(msk, saddr, daddr, &tag); if (req_tag & MCTP_TAG_PREALLOC)
key = mctp_lookup_prealloc_tag(msk, daddr,
req_tag, &tag);
else
key = mctp_alloc_local_tag(msk, daddr, saddr,
false, &tag);
if (IS_ERR(key)) { if (IS_ERR(key)) {
rc = PTR_ERR(key); rc = PTR_ERR(key);
goto out_release; goto out_release;
@ -857,7 +913,7 @@ int mctp_local_output(struct sock *sk, struct mctp_route *rt,
tag |= MCTP_HDR_FLAG_TO; tag |= MCTP_HDR_FLAG_TO;
} else { } else {
key = NULL; key = NULL;
tag = req_tag; tag = req_tag & MCTP_TAG_MASK;
} }
skb->protocol = htons(ETH_P_MCTP); skb->protocol = htons(ETH_P_MCTP);

View File

@ -369,14 +369,15 @@ static void mctp_test_route_input_sk(struct kunit *test)
#define FL_S (MCTP_HDR_FLAG_SOM) #define FL_S (MCTP_HDR_FLAG_SOM)
#define FL_E (MCTP_HDR_FLAG_EOM) #define FL_E (MCTP_HDR_FLAG_EOM)
#define FL_T (MCTP_HDR_FLAG_TO) #define FL_TO (MCTP_HDR_FLAG_TO)
#define FL_T(t) ((t) & MCTP_HDR_TAG_MASK)
static const struct mctp_route_input_sk_test mctp_route_input_sk_tests[] = { static const struct mctp_route_input_sk_test mctp_route_input_sk_tests[] = {
{ .hdr = RX_HDR(1, 10, 8, FL_S | FL_E | FL_T), .type = 0, .deliver = true }, { .hdr = RX_HDR(1, 10, 8, FL_S | FL_E | FL_TO), .type = 0, .deliver = true },
{ .hdr = RX_HDR(1, 10, 8, FL_S | FL_E | FL_T), .type = 1, .deliver = false }, { .hdr = RX_HDR(1, 10, 8, FL_S | FL_E | FL_TO), .type = 1, .deliver = false },
{ .hdr = RX_HDR(1, 10, 8, FL_S | FL_E), .type = 0, .deliver = false }, { .hdr = RX_HDR(1, 10, 8, FL_S | FL_E), .type = 0, .deliver = false },
{ .hdr = RX_HDR(1, 10, 8, FL_E | FL_T), .type = 0, .deliver = false }, { .hdr = RX_HDR(1, 10, 8, FL_E | FL_TO), .type = 0, .deliver = false },
{ .hdr = RX_HDR(1, 10, 8, FL_T), .type = 0, .deliver = false }, { .hdr = RX_HDR(1, 10, 8, FL_TO), .type = 0, .deliver = false },
{ .hdr = RX_HDR(1, 10, 8, 0), .type = 0, .deliver = false }, { .hdr = RX_HDR(1, 10, 8, 0), .type = 0, .deliver = false },
}; };
@ -436,7 +437,7 @@ static void mctp_test_route_input_sk_reasm(struct kunit *test)
__mctp_route_test_fini(test, dev, rt, sock); __mctp_route_test_fini(test, dev, rt, sock);
} }
#define RX_FRAG(f, s) RX_HDR(1, 10, 8, FL_T | (f) | ((s) << MCTP_HDR_SEQ_SHIFT)) #define RX_FRAG(f, s) RX_HDR(1, 10, 8, FL_TO | (f) | ((s) << MCTP_HDR_SEQ_SHIFT))
static const struct mctp_route_input_sk_reasm_test mctp_route_input_sk_reasm_tests[] = { static const struct mctp_route_input_sk_reasm_test mctp_route_input_sk_reasm_tests[] = {
{ {
@ -522,12 +523,156 @@ static void mctp_route_input_sk_reasm_to_desc(
KUNIT_ARRAY_PARAM(mctp_route_input_sk_reasm, mctp_route_input_sk_reasm_tests, KUNIT_ARRAY_PARAM(mctp_route_input_sk_reasm, mctp_route_input_sk_reasm_tests,
mctp_route_input_sk_reasm_to_desc); mctp_route_input_sk_reasm_to_desc);
struct mctp_route_input_sk_keys_test {
const char *name;
mctp_eid_t key_peer_addr;
mctp_eid_t key_local_addr;
u8 key_tag;
struct mctp_hdr hdr;
bool deliver;
};
/* test packet rx in the presence of various key configurations */
static void mctp_test_route_input_sk_keys(struct kunit *test)
{
const struct mctp_route_input_sk_keys_test *params;
struct mctp_test_route *rt;
struct sk_buff *skb, *skb2;
struct mctp_test_dev *dev;
struct mctp_sk_key *key;
struct netns_mctp *mns;
struct mctp_sock *msk;
struct socket *sock;
unsigned long flags;
int rc;
u8 c;
params = test->param_value;
dev = mctp_test_create_dev();
KUNIT_ASSERT_NOT_ERR_OR_NULL(test, dev);
rt = mctp_test_create_route(&init_net, dev->mdev, 8, 68);
KUNIT_ASSERT_NOT_ERR_OR_NULL(test, rt);
rc = sock_create_kern(&init_net, AF_MCTP, SOCK_DGRAM, 0, &sock);
KUNIT_ASSERT_EQ(test, rc, 0);
msk = container_of(sock->sk, struct mctp_sock, sk);
mns = &sock_net(sock->sk)->mctp;
/* set the incoming tag according to test params */
key = mctp_key_alloc(msk, params->key_local_addr, params->key_peer_addr,
params->key_tag, GFP_KERNEL);
KUNIT_ASSERT_NOT_ERR_OR_NULL(test, key);
spin_lock_irqsave(&mns->keys_lock, flags);
mctp_reserve_tag(&init_net, key, msk);
spin_unlock_irqrestore(&mns->keys_lock, flags);
/* create packet and route */
c = 0;
skb = mctp_test_create_skb_data(&params->hdr, &c);
KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb);
skb->dev = dev->ndev;
__mctp_cb(skb);
rc = mctp_route_input(&rt->rt, skb);
/* (potentially) receive message */
skb2 = skb_recv_datagram(sock->sk, 0, 1, &rc);
if (params->deliver)
KUNIT_EXPECT_NOT_ERR_OR_NULL(test, skb2);
else
KUNIT_EXPECT_PTR_EQ(test, skb2, NULL);
if (skb2)
skb_free_datagram(sock->sk, skb2);
mctp_key_unref(key);
__mctp_route_test_fini(test, dev, rt, sock);
}
static const struct mctp_route_input_sk_keys_test mctp_route_input_sk_keys_tests[] = {
{
.name = "direct match",
.key_peer_addr = 9,
.key_local_addr = 8,
.key_tag = 1,
.hdr = RX_HDR(1, 9, 8, FL_S | FL_E | FL_T(1)),
.deliver = true,
},
{
.name = "flipped src/dest",
.key_peer_addr = 8,
.key_local_addr = 9,
.key_tag = 1,
.hdr = RX_HDR(1, 9, 8, FL_S | FL_E | FL_T(1)),
.deliver = false,
},
{
.name = "peer addr mismatch",
.key_peer_addr = 9,
.key_local_addr = 8,
.key_tag = 1,
.hdr = RX_HDR(1, 10, 8, FL_S | FL_E | FL_T(1)),
.deliver = false,
},
{
.name = "tag value mismatch",
.key_peer_addr = 9,
.key_local_addr = 8,
.key_tag = 1,
.hdr = RX_HDR(1, 9, 8, FL_S | FL_E | FL_T(2)),
.deliver = false,
},
{
.name = "TO mismatch",
.key_peer_addr = 9,
.key_local_addr = 8,
.key_tag = 1,
.hdr = RX_HDR(1, 9, 8, FL_S | FL_E | FL_T(1) | FL_TO),
.deliver = false,
},
{
.name = "broadcast response",
.key_peer_addr = MCTP_ADDR_ANY,
.key_local_addr = 8,
.key_tag = 1,
.hdr = RX_HDR(1, 11, 8, FL_S | FL_E | FL_T(1)),
.deliver = true,
},
{
.name = "any local match",
.key_peer_addr = 12,
.key_local_addr = MCTP_ADDR_ANY,
.key_tag = 1,
.hdr = RX_HDR(1, 12, 8, FL_S | FL_E | FL_T(1)),
.deliver = true,
},
};
static void mctp_route_input_sk_keys_to_desc(
const struct mctp_route_input_sk_keys_test *t,
char *desc)
{
sprintf(desc, "%s", t->name);
}
KUNIT_ARRAY_PARAM(mctp_route_input_sk_keys, mctp_route_input_sk_keys_tests,
mctp_route_input_sk_keys_to_desc);
static struct kunit_case mctp_test_cases[] = { static struct kunit_case mctp_test_cases[] = {
KUNIT_CASE_PARAM(mctp_test_fragment, mctp_frag_gen_params), KUNIT_CASE_PARAM(mctp_test_fragment, mctp_frag_gen_params),
KUNIT_CASE_PARAM(mctp_test_rx_input, mctp_rx_input_gen_params), KUNIT_CASE_PARAM(mctp_test_rx_input, mctp_rx_input_gen_params),
KUNIT_CASE_PARAM(mctp_test_route_input_sk, mctp_route_input_sk_gen_params), KUNIT_CASE_PARAM(mctp_test_route_input_sk, mctp_route_input_sk_gen_params),
KUNIT_CASE_PARAM(mctp_test_route_input_sk_reasm, KUNIT_CASE_PARAM(mctp_test_route_input_sk_reasm,
mctp_route_input_sk_reasm_gen_params), mctp_route_input_sk_reasm_gen_params),
KUNIT_CASE_PARAM(mctp_test_route_input_sk_keys,
mctp_route_input_sk_keys_gen_params),
{} {}
}; };