packet: fix send path when running with proto == 0

Commit e40526cb20 introduced a cached dev pointer, that gets
hooked into register_prot_hook(), __unregister_prot_hook() to
update the device used for the send path.

We need to fix this up, as otherwise this will not work with
sockets created with protocol = 0, plus with sll_protocol = 0
passed via sockaddr_ll when doing the bind.

So instead, assign the pointer directly. The compiler can inline
these helper functions automagically.

While at it, also assume the cached dev fast-path as likely(),
and document this variant of socket creation as it seems it is
not widely used (seems not even the author of TX_RING was aware
of that in his reference example [1]). Tested with reproducer
from e40526cb20.

 [1] http://wiki.ipxwarzone.com/index.php5?title=Linux_packet_mmap#Example

Fixes: e40526cb20 ("packet: fix use after free race in send path when dev is released")
Signed-off-by: Daniel Borkmann <dborkman@redhat.com>
Tested-by: Salam Noureddine <noureddine@aristanetworks.com>
Tested-by: Jesper Dangaard Brouer <brouer@redhat.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
This commit is contained in:
Daniel Borkmann 2013-12-06 11:36:15 +01:00 committed by David S. Miller
parent 98bfd23cdb
commit 66e56cd46b
2 changed files with 50 additions and 25 deletions

View File

@ -123,6 +123,16 @@ Transmission process is similar to capture as shown below.
[shutdown] close() --------> destruction of the transmission socket and [shutdown] close() --------> destruction of the transmission socket and
deallocation of all associated resources. deallocation of all associated resources.
Socket creation and destruction is also straight forward, and is done
the same way as in capturing described in the previous paragraph:
int fd = socket(PF_PACKET, mode, 0);
The protocol can optionally be 0 in case we only want to transmit
via this socket, which avoids an expensive call to packet_rcv().
In this case, you also need to bind(2) the TX_RING with sll_protocol = 0
set. Otherwise, htons(ETH_P_ALL) or any other protocol, for example.
Binding the socket to your network interface is mandatory (with zero copy) to Binding the socket to your network interface is mandatory (with zero copy) to
know the header size of frames used in the circular buffer. know the header size of frames used in the circular buffer.

View File

@ -237,6 +237,30 @@ struct packet_skb_cb {
static void __fanout_unlink(struct sock *sk, struct packet_sock *po); static void __fanout_unlink(struct sock *sk, struct packet_sock *po);
static void __fanout_link(struct sock *sk, struct packet_sock *po); static void __fanout_link(struct sock *sk, struct packet_sock *po);
static struct net_device *packet_cached_dev_get(struct packet_sock *po)
{
struct net_device *dev;
rcu_read_lock();
dev = rcu_dereference(po->cached_dev);
if (likely(dev))
dev_hold(dev);
rcu_read_unlock();
return dev;
}
static void packet_cached_dev_assign(struct packet_sock *po,
struct net_device *dev)
{
rcu_assign_pointer(po->cached_dev, dev);
}
static void packet_cached_dev_reset(struct packet_sock *po)
{
RCU_INIT_POINTER(po->cached_dev, NULL);
}
/* register_prot_hook must be invoked with the po->bind_lock held, /* register_prot_hook must be invoked with the po->bind_lock held,
* or from a context in which asynchronous accesses to the packet * or from a context in which asynchronous accesses to the packet
* socket is not possible (packet_create()). * socket is not possible (packet_create()).
@ -246,12 +270,10 @@ static void register_prot_hook(struct sock *sk)
struct packet_sock *po = pkt_sk(sk); struct packet_sock *po = pkt_sk(sk);
if (!po->running) { if (!po->running) {
if (po->fanout) { if (po->fanout)
__fanout_link(sk, po); __fanout_link(sk, po);
} else { else
dev_add_pack(&po->prot_hook); dev_add_pack(&po->prot_hook);
rcu_assign_pointer(po->cached_dev, po->prot_hook.dev);
}
sock_hold(sk); sock_hold(sk);
po->running = 1; po->running = 1;
@ -270,12 +292,11 @@ static void __unregister_prot_hook(struct sock *sk, bool sync)
struct packet_sock *po = pkt_sk(sk); struct packet_sock *po = pkt_sk(sk);
po->running = 0; po->running = 0;
if (po->fanout) {
if (po->fanout)
__fanout_unlink(sk, po); __fanout_unlink(sk, po);
} else { else
__dev_remove_pack(&po->prot_hook); __dev_remove_pack(&po->prot_hook);
RCU_INIT_POINTER(po->cached_dev, NULL);
}
__sock_put(sk); __sock_put(sk);
@ -2059,19 +2080,6 @@ static int tpacket_fill_skb(struct packet_sock *po, struct sk_buff *skb,
return tp_len; return tp_len;
} }
static struct net_device *packet_cached_dev_get(struct packet_sock *po)
{
struct net_device *dev;
rcu_read_lock();
dev = rcu_dereference(po->cached_dev);
if (dev)
dev_hold(dev);
rcu_read_unlock();
return dev;
}
static int tpacket_snd(struct packet_sock *po, struct msghdr *msg) static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
{ {
struct sk_buff *skb; struct sk_buff *skb;
@ -2088,7 +2096,7 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
mutex_lock(&po->pg_vec_lock); mutex_lock(&po->pg_vec_lock);
if (saddr == NULL) { if (likely(saddr == NULL)) {
dev = packet_cached_dev_get(po); dev = packet_cached_dev_get(po);
proto = po->num; proto = po->num;
addr = NULL; addr = NULL;
@ -2242,7 +2250,7 @@ static int packet_snd(struct socket *sock,
* Get and verify the address. * Get and verify the address.
*/ */
if (saddr == NULL) { if (likely(saddr == NULL)) {
dev = packet_cached_dev_get(po); dev = packet_cached_dev_get(po);
proto = po->num; proto = po->num;
addr = NULL; addr = NULL;
@ -2451,6 +2459,8 @@ static int packet_release(struct socket *sock)
spin_lock(&po->bind_lock); spin_lock(&po->bind_lock);
unregister_prot_hook(sk, false); unregister_prot_hook(sk, false);
packet_cached_dev_reset(po);
if (po->prot_hook.dev) { if (po->prot_hook.dev) {
dev_put(po->prot_hook.dev); dev_put(po->prot_hook.dev);
po->prot_hook.dev = NULL; po->prot_hook.dev = NULL;
@ -2506,14 +2516,17 @@ static int packet_do_bind(struct sock *sk, struct net_device *dev, __be16 protoc
spin_lock(&po->bind_lock); spin_lock(&po->bind_lock);
unregister_prot_hook(sk, true); unregister_prot_hook(sk, true);
po->num = protocol; po->num = protocol;
po->prot_hook.type = protocol; po->prot_hook.type = protocol;
if (po->prot_hook.dev) if (po->prot_hook.dev)
dev_put(po->prot_hook.dev); dev_put(po->prot_hook.dev);
po->prot_hook.dev = dev;
po->prot_hook.dev = dev;
po->ifindex = dev ? dev->ifindex : 0; po->ifindex = dev ? dev->ifindex : 0;
packet_cached_dev_assign(po, dev);
if (protocol == 0) if (protocol == 0)
goto out_unlock; goto out_unlock;
@ -2626,7 +2639,8 @@ static int packet_create(struct net *net, struct socket *sock, int protocol,
po = pkt_sk(sk); po = pkt_sk(sk);
sk->sk_family = PF_PACKET; sk->sk_family = PF_PACKET;
po->num = proto; po->num = proto;
RCU_INIT_POINTER(po->cached_dev, NULL);
packet_cached_dev_reset(po);
sk->sk_destruct = packet_sock_destruct; sk->sk_destruct = packet_sock_destruct;
sk_refcnt_debug_inc(sk); sk_refcnt_debug_inc(sk);
@ -3337,6 +3351,7 @@ static int packet_notifier(struct notifier_block *this,
sk->sk_error_report(sk); sk->sk_error_report(sk);
} }
if (msg == NETDEV_UNREGISTER) { if (msg == NETDEV_UNREGISTER) {
packet_cached_dev_reset(po);
po->ifindex = -1; po->ifindex = -1;
if (po->prot_hook.dev) if (po->prot_hook.dev)
dev_put(po->prot_hook.dev); dev_put(po->prot_hook.dev);