diff --git a/Documentation/networking/mctp.rst b/Documentation/networking/mctp.rst index 6100cdc220f6..2c54b029f990 100644 --- a/Documentation/networking/mctp.rst +++ b/Documentation/networking/mctp.rst @@ -211,3 +211,62 @@ remote address is already known, or the message does not require a reply. Like the send calls, sockets will only receive responses to requests they have sent (TO=1) and may only respond (TO=0) to requests they have received. + +Kernel internals +================ + +There are a few possible packet flows in the MCTP stack: + +1. local TX to remote endpoint, message <= MTU:: + + sendmsg() + -> mctp_local_output() + : route lookup + -> rt->output() (== mctp_route_output) + -> dev_queue_xmit() + +2. local TX to remote endpoint, message > MTU:: + + sendmsg() + -> mctp_local_output() + -> mctp_do_fragment_route() + : creates packet-sized skbs. For each new skb: + -> rt->output() (== mctp_route_output) + -> dev_queue_xmit() + +3. remote TX to local endpoint, single-packet message:: + + mctp_pkttype_receive() + : route lookup + -> rt->output() (== mctp_route_input) + : sk_key lookup + -> sock_queue_rcv_skb() + +4. remote TX to local endpoint, multiple-packet message:: + + mctp_pkttype_receive() + : route lookup + -> rt->output() (== mctp_route_input) + : sk_key lookup + : stores skb in struct sk_key->reasm_head + + mctp_pkttype_receive() + : route lookup + -> rt->output() (== mctp_route_input) + : sk_key lookup + : finds existing reassembly in sk_key->reasm_head + : appends new fragment + -> sock_queue_rcv_skb() + +Key refcounts +------------- + + * keys are refed by: + + - a skb: during route output, stored in ``skb->cb``. + + - netns and sock lists. + + * keys can be associated with a device, in which case they hold a + reference to the dev (set through ``key->dev``, counted through + ``dev->key_count``). Multiple keys can reference the device. diff --git a/include/net/mctp.h b/include/net/mctp.h index a824d47c3c6d..b9ed62a63c24 100644 --- a/include/net/mctp.h +++ b/include/net/mctp.h @@ -62,35 +62,46 @@ struct mctp_sock { * by sk->net->keys_lock */ struct hlist_head keys; + + /* mechanism for expiring allocated keys; will release an allocated + * tag, and any netdev state for a request/response pairing + */ + struct timer_list key_expiry; }; /* Key for matching incoming packets to sockets or reassembly contexts. * Packets are matched on (src,dest,tag). * - * Lifetime requirements: + * Lifetime / locking requirements: * - * - keys are free()ed via RCU + * - individual key data (ie, the struct itself) is protected by key->lock; + * changes must be made with that lock held. + * + * - the lookup fields: peer_addr, local_addr and tag are set before the + * key is added to lookup lists, and never updated. + * + * - A ref to the key must be held (throuh key->refs) if a pointer to the + * key is to be accessed after key->lock is released. * * - a mctp_sk_key contains a reference to a struct sock; this is valid * for the life of the key. On sock destruction (through unhash), the key is - * removed from lists (see below), and will not be observable after a RCU - * grace period. - * - * any RX occurring within that grace period may still queue to the socket, - * but will hit the SOCK_DEAD case before the socket is freed. + * removed from lists (see below), and marked invalid. * * - these mctp_sk_keys appear on two lists: * 1) the struct mctp_sock->keys list * 2) the struct netns_mctp->keys list * - * updates to either list are performed under the netns_mctp->keys - * lock. + * presences on these lists requires a (single) refcount to be held; both + * lists are updated as a single operation. + * + * Updates and lookups in either list are performed under the + * netns_mctp->keys lock. Lookup functions will need to lock the key and + * take a reference before unlocking the keys_lock. Consequently, the list's + * keys_lock *cannot* be acquired with the individual key->lock held. * * - a key may have a sk_buff attached as part of an in-progress message - * reassembly (->reasm_head). The reassembly context is protected by - * reasm_lock, which may be acquired with the keys lock (above) held, if - * necessary. Consequently, keys lock *cannot* be acquired with the - * reasm_lock held. + * reassembly (->reasm_head). The reasm data is protected by the individual + * key->lock. * * - there are two destruction paths for a mctp_sk_key: * @@ -101,6 +112,8 @@ struct mctp_sock { * the (complete) reply, or during reassembly errors. Here, we clean up * the reassembly context (marking reasm_dead, to prevent another from * starting), and remove the socket from the netns & socket lists. + * + * - through an expiry timeout, on a per-socket timer */ struct mctp_sk_key { mctp_eid_t peer_addr; @@ -116,14 +129,25 @@ struct mctp_sk_key { /* per-socket list */ struct hlist_node sklist; + /* lock protects against concurrent updates to the reassembly and + * expiry data below. + */ + spinlock_t lock; + + /* Keys are referenced during the output path, which may sleep */ + refcount_t refs; + /* incoming fragment reassembly context */ - spinlock_t reasm_lock; struct sk_buff *reasm_head; struct sk_buff **reasm_tailp; bool reasm_dead; u8 last_seq; - struct rcu_head rcu; + /* key validity */ + bool valid; + + /* expiry timeout; valid (above) cleared on expiry */ + unsigned long expiry; }; struct mctp_skb_cb { @@ -191,6 +215,8 @@ int mctp_do_route(struct mctp_route *rt, struct sk_buff *skb); int mctp_local_output(struct sock *sk, struct mctp_route *rt, struct sk_buff *skb, mctp_eid_t daddr, u8 req_tag); +void mctp_key_unref(struct mctp_sk_key *key); + /* routing <--> device interface */ unsigned int mctp_default_net(struct net *net); int mctp_default_net_set(struct net *net, unsigned int index); diff --git a/include/net/mctpdevice.h b/include/net/mctpdevice.h index 71a11012fac7..3a439463f055 100644 --- a/include/net/mctpdevice.h +++ b/include/net/mctpdevice.h @@ -17,6 +17,8 @@ struct mctp_dev { struct net_device *dev; + refcount_t refs; + unsigned int net; /* Only modified under RTNL. Reads have addrs_lock held */ @@ -32,4 +34,7 @@ struct mctp_dev { struct mctp_dev *mctp_dev_get_rtnl(const struct net_device *dev); struct mctp_dev *__mctp_dev_get(const struct net_device *dev); +void mctp_dev_hold(struct mctp_dev *mdev); +void mctp_dev_put(struct mctp_dev *mdev); + #endif /* __NET_MCTPDEVICE_H */ diff --git a/include/trace/events/mctp.h b/include/trace/events/mctp.h new file mode 100644 index 000000000000..175b057c507f --- /dev/null +++ b/include/trace/events/mctp.h @@ -0,0 +1,75 @@ +/* SPDX-License-Identifier: GPL-2.0 */ + +#undef TRACE_SYSTEM +#define TRACE_SYSTEM mctp + +#if !defined(_TRACE_MCTP_H) || defined(TRACE_HEADER_MULTI_READ) +#define _TRACE_MCTP_H + +#include + +#ifndef __TRACE_MCTP_ENUMS +#define __TRACE_MCTP_ENUMS +enum { + MCTP_TRACE_KEY_TIMEOUT, + MCTP_TRACE_KEY_REPLIED, + MCTP_TRACE_KEY_INVALIDATED, + MCTP_TRACE_KEY_CLOSED, +}; +#endif /* __TRACE_MCTP_ENUMS */ + +TRACE_DEFINE_ENUM(MCTP_TRACE_KEY_TIMEOUT); +TRACE_DEFINE_ENUM(MCTP_TRACE_KEY_REPLIED); +TRACE_DEFINE_ENUM(MCTP_TRACE_KEY_INVALIDATED); +TRACE_DEFINE_ENUM(MCTP_TRACE_KEY_CLOSED); + +TRACE_EVENT(mctp_key_acquire, + TP_PROTO(const struct mctp_sk_key *key), + TP_ARGS(key), + TP_STRUCT__entry( + __field(__u8, paddr) + __field(__u8, laddr) + __field(__u8, tag) + ), + TP_fast_assign( + __entry->paddr = key->peer_addr; + __entry->laddr = key->local_addr; + __entry->tag = key->tag; + ), + TP_printk("local %d, peer %d, tag %1x", + __entry->laddr, + __entry->paddr, + __entry->tag + ) +); + +TRACE_EVENT(mctp_key_release, + TP_PROTO(const struct mctp_sk_key *key, int reason), + TP_ARGS(key, reason), + TP_STRUCT__entry( + __field(__u8, paddr) + __field(__u8, laddr) + __field(__u8, tag) + __field(int, reason) + ), + TP_fast_assign( + __entry->paddr = key->peer_addr; + __entry->laddr = key->local_addr; + __entry->tag = key->tag; + __entry->reason = reason; + ), + TP_printk("local %d, peer %d, tag %1x %s", + __entry->laddr, + __entry->paddr, + __entry->tag, + __print_symbolic(__entry->reason, + { MCTP_TRACE_KEY_TIMEOUT, "timeout" }, + { MCTP_TRACE_KEY_REPLIED, "replied" }, + { MCTP_TRACE_KEY_INVALIDATED, "invalidated" }, + { MCTP_TRACE_KEY_CLOSED, "closed" }) + ) +); + +#endif + +#include diff --git a/net/mctp/af_mctp.c b/net/mctp/af_mctp.c index a9526ac29dff..66a411d60b6c 100644 --- a/net/mctp/af_mctp.c +++ b/net/mctp/af_mctp.c @@ -16,6 +16,9 @@ #include #include +#define CREATE_TRACE_POINTS +#include + /* socket implementation */ static int mctp_release(struct socket *sock) @@ -223,16 +226,61 @@ static const struct proto_ops mctp_dgram_ops = { .sendpage = sock_no_sendpage, }; +static void mctp_sk_expire_keys(struct timer_list *timer) +{ + struct mctp_sock *msk = container_of(timer, struct mctp_sock, + key_expiry); + struct net *net = sock_net(&msk->sk); + unsigned long next_expiry, flags; + struct mctp_sk_key *key; + struct hlist_node *tmp; + bool next_expiry_valid = false; + + spin_lock_irqsave(&net->mctp.keys_lock, flags); + + hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) { + spin_lock(&key->lock); + + if (!time_after_eq(key->expiry, jiffies)) { + trace_mctp_key_release(key, MCTP_TRACE_KEY_TIMEOUT); + key->valid = false; + hlist_del_rcu(&key->hlist); + hlist_del_rcu(&key->sklist); + spin_unlock(&key->lock); + mctp_key_unref(key); + continue; + } + + if (next_expiry_valid) { + if (time_before(key->expiry, next_expiry)) + next_expiry = key->expiry; + } else { + next_expiry = key->expiry; + next_expiry_valid = true; + } + spin_unlock(&key->lock); + } + + spin_unlock_irqrestore(&net->mctp.keys_lock, flags); + + if (next_expiry_valid) + mod_timer(timer, next_expiry); +} + static int mctp_sk_init(struct sock *sk) { struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); INIT_HLIST_HEAD(&msk->keys); + timer_setup(&msk->key_expiry, mctp_sk_expire_keys, 0); return 0; } static void mctp_sk_close(struct sock *sk, long timeout) { + struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); + + del_timer_sync(&msk->key_expiry); sk_common_release(sk); } @@ -263,21 +311,23 @@ static void mctp_sk_unhash(struct sock *sk) /* remove tag allocations */ spin_lock_irqsave(&net->mctp.keys_lock, flags); hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) { - hlist_del_rcu(&key->sklist); - hlist_del_rcu(&key->hlist); + hlist_del(&key->sklist); + hlist_del(&key->hlist); - spin_lock(&key->reasm_lock); + trace_mctp_key_release(key, MCTP_TRACE_KEY_CLOSED); + + spin_lock(&key->lock); if (key->reasm_head) kfree_skb(key->reasm_head); key->reasm_head = NULL; key->reasm_dead = true; - spin_unlock(&key->reasm_lock); + key->valid = false; + spin_unlock(&key->lock); - kfree_rcu(key, rcu); + /* key is no longer on the lookup lists, unref */ + mctp_key_unref(key); } spin_unlock_irqrestore(&net->mctp.keys_lock, flags); - - synchronize_rcu(); } static struct proto mctp_proto = { @@ -385,7 +435,7 @@ static __exit void mctp_exit(void) sock_unregister(PF_MCTP); } -module_init(mctp_init); +subsys_initcall(mctp_init); module_exit(mctp_exit); MODULE_DESCRIPTION("MCTP core"); diff --git a/net/mctp/device.c b/net/mctp/device.c index b9f38e765f61..3827d62f52c9 100644 --- a/net/mctp/device.c +++ b/net/mctp/device.c @@ -35,14 +35,6 @@ struct mctp_dev *mctp_dev_get_rtnl(const struct net_device *dev) return rtnl_dereference(dev->mctp_ptr); } -static void mctp_dev_destroy(struct mctp_dev *mdev) -{ - struct net_device *dev = mdev->dev; - - dev_put(dev); - kfree_rcu(mdev, rcu); -} - static int mctp_fill_addrinfo(struct sk_buff *skb, struct netlink_callback *cb, struct mctp_dev *mdev, mctp_eid_t eid) { @@ -255,6 +247,19 @@ static int mctp_rtm_deladdr(struct sk_buff *skb, struct nlmsghdr *nlh, return 0; } +void mctp_dev_hold(struct mctp_dev *mdev) +{ + refcount_inc(&mdev->refs); +} + +void mctp_dev_put(struct mctp_dev *mdev) +{ + if (refcount_dec_and_test(&mdev->refs)) { + dev_put(mdev->dev); + kfree_rcu(mdev, rcu); + } +} + static struct mctp_dev *mctp_add_dev(struct net_device *dev) { struct mctp_dev *mdev; @@ -270,7 +275,9 @@ static struct mctp_dev *mctp_add_dev(struct net_device *dev) mdev->net = mctp_default_net(dev_net(dev)); /* associate to net_device */ + refcount_set(&mdev->refs, 1); rcu_assign_pointer(dev->mctp_ptr, mdev); + dev_hold(dev); mdev->dev = dev; @@ -330,12 +337,26 @@ static int mctp_set_link_af(struct net_device *dev, const struct nlattr *attr, return 0; } +/* Matches netdev types that should have MCTP handling */ +static bool mctp_known(struct net_device *dev) +{ + /* only register specific types (inc. NONE for TUN devices) */ + return dev->type == ARPHRD_MCTP || + dev->type == ARPHRD_LOOPBACK || + dev->type == ARPHRD_NONE; +} + static void mctp_unregister(struct net_device *dev) { struct mctp_dev *mdev; mdev = mctp_dev_get_rtnl(dev); - + if (mctp_known(dev) != (bool)mdev) { + // Sanity check, should match what was set in mctp_register + netdev_warn(dev, "%s: mdev pointer %d but type (%d) match is %d", + __func__, (bool)mdev, mctp_known(dev), dev->type); + return; + } if (!mdev) return; @@ -345,7 +366,7 @@ static void mctp_unregister(struct net_device *dev) mctp_neigh_remove_dev(mdev); kfree(mdev->addrs); - mctp_dev_destroy(mdev); + mctp_dev_put(mdev); } static int mctp_register(struct net_device *dev) @@ -353,11 +374,17 @@ static int mctp_register(struct net_device *dev) struct mctp_dev *mdev; /* Already registered? */ - if (rtnl_dereference(dev->mctp_ptr)) - return 0; + mdev = rtnl_dereference(dev->mctp_ptr); - /* only register specific types; MCTP-specific and loopback for now */ - if (dev->type != ARPHRD_MCTP && dev->type != ARPHRD_LOOPBACK) + if (mdev) { + if (!mctp_known(dev)) + netdev_warn(dev, "%s: mctp_dev set for unknown type %d", + __func__, dev->type); + return 0; + } + + /* only register specific types */ + if (!mctp_known(dev)) return 0; mdev = mctp_add_dev(dev); diff --git a/net/mctp/neigh.c b/net/mctp/neigh.c index 90ed2f02d1fb..5cc042121493 100644 --- a/net/mctp/neigh.c +++ b/net/mctp/neigh.c @@ -47,7 +47,7 @@ static int mctp_neigh_add(struct mctp_dev *mdev, mctp_eid_t eid, } INIT_LIST_HEAD(&neigh->list); neigh->dev = mdev; - dev_hold(neigh->dev->dev); + mctp_dev_hold(neigh->dev); neigh->eid = eid; neigh->source = source; memcpy(neigh->ha, lladdr, lladdr_len); @@ -63,7 +63,7 @@ static void __mctp_neigh_free(struct rcu_head *rcu) { struct mctp_neigh *neigh = container_of(rcu, struct mctp_neigh, rcu); - dev_put(neigh->dev->dev); + mctp_dev_put(neigh->dev); kfree(neigh); } diff --git a/net/mctp/route.c b/net/mctp/route.c index 5ca186d53cb0..e20f3096d067 100644 --- a/net/mctp/route.c +++ b/net/mctp/route.c @@ -23,7 +23,10 @@ #include #include +#include + static const unsigned int mctp_message_maxlen = 64 * 1024; +static const unsigned long mctp_key_lifetime = 6 * CONFIG_HZ; /* route output callbacks */ static int mctp_route_discard(struct mctp_route *route, struct sk_buff *skb) @@ -83,25 +86,43 @@ static bool mctp_key_match(struct mctp_sk_key *key, mctp_eid_t local, return true; } +/* returns a key (with key->lock held, and refcounted), or NULL if no such + * key exists. + */ static struct mctp_sk_key *mctp_lookup_key(struct net *net, struct sk_buff *skb, - mctp_eid_t peer) + mctp_eid_t peer, + unsigned long *irqflags) + __acquires(&key->lock) { struct mctp_sk_key *key, *ret; + unsigned long flags; struct mctp_hdr *mh; u8 tag; - WARN_ON(!rcu_read_lock_held()); - mh = mctp_hdr(skb); tag = mh->flags_seq_tag & (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO); ret = NULL; + spin_lock_irqsave(&net->mctp.keys_lock, flags); - hlist_for_each_entry_rcu(key, &net->mctp.keys, hlist) { - if (mctp_key_match(key, mh->dest, peer, tag)) { + hlist_for_each_entry(key, &net->mctp.keys, hlist) { + if (!mctp_key_match(key, mh->dest, peer, tag)) + continue; + + spin_lock(&key->lock); + if (key->valid) { + refcount_inc(&key->refs); ret = key; break; } + spin_unlock(&key->lock); + } + + if (ret) { + spin_unlock(&net->mctp.keys_lock); + *irqflags = flags; + } else { + spin_unlock_irqrestore(&net->mctp.keys_lock, flags); } return ret; @@ -121,11 +142,19 @@ static struct mctp_sk_key *mctp_key_alloc(struct mctp_sock *msk, key->local_addr = local; key->tag = tag; key->sk = &msk->sk; - spin_lock_init(&key->reasm_lock); + key->valid = true; + spin_lock_init(&key->lock); + refcount_set(&key->refs, 1); return key; } +void mctp_key_unref(struct mctp_sk_key *key) +{ + if (refcount_dec_and_test(&key->refs)) + kfree(key); +} + static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk) { struct net *net = sock_net(&msk->sk); @@ -138,12 +167,20 @@ static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk) hlist_for_each_entry(tmp, &net->mctp.keys, hlist) { if (mctp_key_match(tmp, key->local_addr, key->peer_addr, key->tag)) { - rc = -EEXIST; - break; + spin_lock(&tmp->lock); + if (tmp->valid) + rc = -EEXIST; + spin_unlock(&tmp->lock); + if (rc) + break; } } if (!rc) { + refcount_inc(&key->refs); + key->expiry = jiffies + mctp_key_lifetime; + timer_reduce(&msk->key_expiry, key->expiry); + hlist_add_head(&key->hlist, &net->mctp.keys); hlist_add_head(&key->sklist, &msk->keys); } @@ -153,28 +190,35 @@ static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk) return rc; } -/* Must be called with key->reasm_lock, which it will release. Will schedule - * the key for an RCU free. +/* We're done with the key; unset valid and remove from lists. There may still + * be outstanding refs on the key though... */ static void __mctp_key_unlock_drop(struct mctp_sk_key *key, struct net *net, unsigned long flags) - __releases(&key->reasm_lock) + __releases(&key->lock) { struct sk_buff *skb; skb = key->reasm_head; key->reasm_head = NULL; key->reasm_dead = true; - spin_unlock_irqrestore(&key->reasm_lock, flags); + key->valid = false; + spin_unlock_irqrestore(&key->lock, flags); spin_lock_irqsave(&net->mctp.keys_lock, flags); - hlist_del_rcu(&key->hlist); - hlist_del_rcu(&key->sklist); + hlist_del(&key->hlist); + hlist_del(&key->sklist); spin_unlock_irqrestore(&net->mctp.keys_lock, flags); - kfree_rcu(key, rcu); + + /* one unref for the lists */ + mctp_key_unref(key); + + /* and one for the local reference */ + mctp_key_unref(key); if (skb) kfree_skb(skb); + } static int mctp_frag_queue(struct mctp_sk_key *key, struct sk_buff *skb) @@ -248,8 +292,10 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb) rcu_read_lock(); - /* lookup socket / reasm context, exactly matching (src,dest,tag) */ - key = mctp_lookup_key(net, skb, mh->src); + /* lookup socket / reasm context, exactly matching (src,dest,tag). + * we hold a ref on the key, and key->lock held. + */ + key = mctp_lookup_key(net, skb, mh->src, &f); if (flags & MCTP_HDR_FLAG_SOM) { if (key) { @@ -260,10 +306,12 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb) * key for reassembly - we'll create a more specific * one for future packets if required (ie, !EOM). */ - key = mctp_lookup_key(net, skb, MCTP_ADDR_ANY); + key = mctp_lookup_key(net, skb, MCTP_ADDR_ANY, &f); if (key) { msk = container_of(key->sk, struct mctp_sock, sk); + spin_unlock_irqrestore(&key->lock, f); + mctp_key_unref(key); key = NULL; } } @@ -282,11 +330,13 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb) if (flags & MCTP_HDR_FLAG_EOM) { sock_queue_rcv_skb(&msk->sk, skb); if (key) { - spin_lock_irqsave(&key->reasm_lock, f); /* we've hit a pending reassembly; not much we * can do but drop it */ + trace_mctp_key_release(key, + MCTP_TRACE_KEY_REPLIED); __mctp_key_unlock_drop(key, net, f); + key = NULL; } rc = 0; goto out_unlock; @@ -303,7 +353,7 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb) goto out_unlock; } - /* we can queue without the reasm lock here, as the + /* we can queue without the key lock here, as the * key isn't observable yet */ mctp_frag_queue(key, skb); @@ -318,17 +368,21 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb) if (rc) kfree(key); - } else { - /* existing key: start reassembly */ - spin_lock_irqsave(&key->reasm_lock, f); + trace_mctp_key_acquire(key); + /* we don't need to release key->lock on exit */ + key = NULL; + + } else { if (key->reasm_head || key->reasm_dead) { /* duplicate start? drop everything */ + trace_mctp_key_release(key, + MCTP_TRACE_KEY_INVALIDATED); __mctp_key_unlock_drop(key, net, f); rc = -EEXIST; + key = NULL; } else { rc = mctp_frag_queue(key, skb); - spin_unlock_irqrestore(&key->reasm_lock, f); } } @@ -337,8 +391,6 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb) * using the message-specific key */ - spin_lock_irqsave(&key->reasm_lock, f); - /* we need to be continuing an existing reassembly... */ if (!key->reasm_head) rc = -EINVAL; @@ -351,9 +403,9 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb) if (!rc && flags & MCTP_HDR_FLAG_EOM) { sock_queue_rcv_skb(key->sk, key->reasm_head); key->reasm_head = NULL; + trace_mctp_key_release(key, MCTP_TRACE_KEY_REPLIED); __mctp_key_unlock_drop(key, net, f); - } else { - spin_unlock_irqrestore(&key->reasm_lock, f); + key = NULL; } } else { @@ -363,6 +415,10 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb) out_unlock: rcu_read_unlock(); + if (key) { + spin_unlock_irqrestore(&key->lock, f); + mctp_key_unref(key); + } out: if (rc) kfree_skb(skb); @@ -412,7 +468,7 @@ static int mctp_route_output(struct mctp_route *route, struct sk_buff *skb) static void mctp_route_release(struct mctp_route *rt) { if (refcount_dec_and_test(&rt->refs)) { - dev_put(rt->dev->dev); + mctp_dev_put(rt->dev); kfree_rcu(rt, rcu); } } @@ -454,11 +510,15 @@ static void mctp_reserve_tag(struct net *net, struct mctp_sk_key *key, lockdep_assert_held(&mns->keys_lock); + key->expiry = jiffies + mctp_key_lifetime; + timer_reduce(&msk->key_expiry, key->expiry); + /* we hold the net->key_lock here, allowing updates to both * then net and sk */ hlist_add_head_rcu(&key->hlist, &mns->keys); hlist_add_head_rcu(&key->sklist, &msk->keys); + refcount_inc(&key->refs); } /* Allocate a locally-owned tag value for (saddr, daddr), and reserve @@ -474,6 +534,10 @@ static int mctp_alloc_local_tag(struct mctp_sock *msk, int rc = -EAGAIN; u8 tagbits; + /* for NULL destination EIDs, we may get a response from any peer */ + if (daddr == MCTP_ADDR_NULL) + daddr = MCTP_ADDR_ANY; + /* be optimistic, alloc now */ key = mctp_key_alloc(msk, saddr, daddr, 0, GFP_KERNEL); if (!key) @@ -488,14 +552,26 @@ static int mctp_alloc_local_tag(struct mctp_sock *msk, * tags. If we find a conflict, clear that bit from tagbits */ hlist_for_each_entry(tmp, &mns->keys, hlist) { + /* We can check the lookup fields (*_addr, tag) without the + * lock held, they don't change over the lifetime of the key. + */ + /* if we don't own the tag, it can't conflict */ if (tmp->tag & MCTP_HDR_FLAG_TO) continue; - if ((tmp->peer_addr == daddr || - tmp->peer_addr == MCTP_ADDR_ANY) && - tmp->local_addr == saddr) + if (!((tmp->peer_addr == daddr || + tmp->peer_addr == MCTP_ADDR_ANY) && + tmp->local_addr == saddr)) + continue; + + spin_lock(&tmp->lock); + /* key must still be valid. If we find a match, clear the + * potential tag value + */ + if (tmp->valid) tagbits &= ~(1 << tmp->tag); + spin_unlock(&tmp->lock); if (!tagbits) break; @@ -504,6 +580,8 @@ static int mctp_alloc_local_tag(struct mctp_sock *msk, if (tagbits) { key->tag = __ffs(tagbits); mctp_reserve_tag(net, key, msk); + trace_mctp_key_acquire(key); + *tagp = key->tag; rc = 0; } @@ -552,6 +630,20 @@ struct mctp_route *mctp_route_lookup(struct net *net, unsigned int dnet, return rt; } +static struct mctp_route *mctp_route_lookup_null(struct net *net, + struct net_device *dev) +{ + struct mctp_route *rt; + + list_for_each_entry_rcu(rt, &net->mctp.routes, list) { + if (rt->dev->dev == dev && rt->type == RTN_LOCAL && + refcount_inc_not_zero(&rt->refs)) + return rt; + } + + return NULL; +} + /* sends a skb to rt and releases the route. */ int mctp_do_route(struct mctp_route *rt, struct sk_buff *skb) { @@ -741,7 +833,7 @@ static int mctp_route_add(struct mctp_dev *mdev, mctp_eid_t daddr_start, rt->max = daddr_start + daddr_extent; rt->mtu = mtu; rt->dev = mdev; - dev_hold(rt->dev->dev); + mctp_dev_hold(rt->dev); rt->type = type; rt->output = rtfn; @@ -821,13 +913,18 @@ static int mctp_pkttype_receive(struct sk_buff *skb, struct net_device *dev, struct net_device *orig_dev) { struct net *net = dev_net(dev); + struct mctp_dev *mdev; struct mctp_skb_cb *cb; struct mctp_route *rt; struct mctp_hdr *mh; - /* basic non-data sanity checks */ - if (dev->type != ARPHRD_MCTP) + rcu_read_lock(); + mdev = __mctp_dev_get(dev); + rcu_read_unlock(); + if (!mdev) { + /* basic non-data sanity checks */ goto err_drop; + } if (!pskb_may_pull(skb, sizeof(struct mctp_hdr))) goto err_drop; @@ -841,11 +938,14 @@ static int mctp_pkttype_receive(struct sk_buff *skb, struct net_device *dev, goto err_drop; cb = __mctp_cb(skb); - rcu_read_lock(); - cb->net = READ_ONCE(__mctp_dev_get(dev)->net); - rcu_read_unlock(); + cb->net = READ_ONCE(mdev->net); rt = mctp_route_lookup(net, cb->net, mh->dest); + + /* NULL EID, but addressed to our physical address */ + if (!rt && mh->dest == MCTP_ADDR_NULL && skb->pkt_type == PACKET_HOST) + rt = mctp_route_lookup_null(net, dev); + if (!rt) goto err_drop; @@ -926,10 +1026,15 @@ static int mctp_route_nlparse(struct sk_buff *skb, struct nlmsghdr *nlh, return 0; } +static const struct nla_policy rta_metrics_policy[RTAX_MAX + 1] = { + [RTAX_MTU] = { .type = NLA_U32 }, +}; + static int mctp_newroute(struct sk_buff *skb, struct nlmsghdr *nlh, struct netlink_ext_ack *extack) { struct nlattr *tb[RTA_MAX + 1]; + struct nlattr *tbx[RTAX_MAX + 1]; mctp_eid_t daddr_start; struct mctp_dev *mdev; struct rtmsg *rtm; @@ -946,8 +1051,15 @@ static int mctp_newroute(struct sk_buff *skb, struct nlmsghdr *nlh, return -EINVAL; } - /* TODO: parse mtu from nlparse */ mtu = 0; + if (tb[RTA_METRICS]) { + rc = nla_parse_nested(tbx, RTAX_MAX, tb[RTA_METRICS], + rta_metrics_policy, NULL); + if (rc < 0) + return rc; + if (tbx[RTAX_MTU]) + mtu = nla_get_u32(tbx[RTAX_MTU]); + } if (rtm->rtm_type != RTN_UNICAST) return -EINVAL;