Merge branch 'tls-rx-avoid-skb_cow_data'

Jakub Kicinski says:

====================
tls: rx: avoid skb_cow_data()

TLS calls skb_cow_data() on the skb it received from strparser
whenever it needs to hold onto the skb with the decrypted data.
(The alternative being decrypting directly to a user space buffer
in which case the input skb doesn't get modified or used after.)
TLS needs the decrypted skb:
 - almost always with TLS 1.3 (unless the new NoPad is enabled);
 - when the user space buffer is too small to fit the record;
 - when BPF sockmap is enabled.

Most of the time the skb we get out of strparser is a clone of
a 64kB data unit coalesced by GRO. To make things worse, skb_cow_data()
tries to output a linear skb and allocates it with GFP_ATOMIC.
This occasionally fails even under moderate memory pressure.

This patch set rejigs the TLS Rx so that we don't expect decryption
in place. The decryption handlers return an skb which may or may not
be the skb from strparser. For TLS 1.3 this results in a 20-30%
performance improvement without NoPad enabled.

v2: rebase after 3d8c51b25a ("net/tls: Check for errors in tls_device_init")
====================
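To make the new calling convention concrete, here is a minimal sketch of the in/out argument the decrypt handlers share after this series. The struct is the tls_decrypt_arg added in net/tls/tls_sw.c below; the field annotations are editorial, distilled from the handler table in the same file:

	struct tls_decrypt_arg {
		struct_group(inargs,
			bool zc;	/* in: zero-copy allowed; out: zero-copy performed */
			bool async;	/* in: async allowed; out: async crypto in progress */
			u8 tail;	/* out: TLS 1.3 trailing content-type byte */
		);
		struct sk_buff *skb;	/* out: skb holding the decrypted record */
	};

Callers memset() the inargs group before each record and read darg.skb afterwards; the output skb may be the strparser skb itself or a freshly allocated clear skb.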

Signed-off-by: David S. Miller <davem@davemloft.net>
Commit fd18d5f132, committed by David S. Miller on 2022-07-18 11:24:11 +01:00.
7 changed files with 334 additions and 195 deletions

include/linux/skmsg.h

@@ -72,7 +72,6 @@ struct sk_skb_cb {
/* strp users' data follows */
struct tls_msg {
u8 control;
u8 decrypted;
} tls;
/* temp_reg is a temporary register used for bpf_convert_data_end_access
* when dst_reg == src_reg.

include/net/tls.h

@@ -116,11 +116,15 @@ struct tls_sw_context_rx {
void (*saved_data_ready)(struct sock *sk);
struct sk_buff *recv_pkt;
u8 reader_present;
u8 async_capable:1;
u8 zc_capable:1;
u8 reader_contended:1;
atomic_t decrypt_pending;
/* protect crypto_wait with decrypt_pending */
spinlock_t decrypt_compl_lock;
struct sk_buff_head async_hold;
struct wait_queue_head wq;
};
struct tls_record_info {

net/tls/Makefile

@@ -7,7 +7,7 @@ CFLAGS_trace.o := -I$(src)
obj-$(CONFIG_TLS) += tls.o
tls-y := tls_main.o tls_sw.o tls_proc.o trace.o
tls-y := tls_main.o tls_sw.o tls_proc.o trace.o tls_strp.o
tls-$(CONFIG_TLS_TOE) += tls_toe.o
tls-$(CONFIG_TLS_DEVICE) += tls_device.o tls_device_fallback.o

net/tls/tls.h

@@ -39,6 +39,9 @@
#include <linux/skmsg.h>
#include <net/tls.h>
#define TLS_PAGE_ORDER (min_t(unsigned int, PAGE_ALLOC_COSTLY_ORDER, \
TLS_MAX_PAYLOAD_SIZE >> PAGE_SHIFT))
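/* Editorial note: with 4 kB pages (PAGE_SHIFT == 12) this works out to
 * min(PAGE_ALLOC_COSTLY_ORDER (3), 16384 >> 12 (4)) == 3, i.e. it caps the
 * compound pages used by tls_alloc_clrtxt_skb() in tls_sw.c at order-3
 * (32 kB), staying within the "costly order" allocation threshold.
 */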
#define __TLS_INC_STATS(net, field) \
__SNMP_INC_STATS((net)->mib.tls_statistics, field)
#define TLS_INC_STATS(net, field) \
@@ -118,13 +121,15 @@ void tls_device_write_space(struct sock *sk, struct tls_context *ctx);
int tls_process_cmsg(struct sock *sk, struct msghdr *msg,
unsigned char *record_type);
int decrypt_skb(struct sock *sk, struct sk_buff *skb,
struct scatterlist *sgout);
int decrypt_skb(struct sock *sk, struct scatterlist *sgout);
int tls_sw_fallback_init(struct sock *sk,
struct tls_offload_context_tx *offload_ctx,
struct tls_crypto_info *crypto_info);
int tls_strp_msg_hold(struct sock *sk, struct sk_buff *skb,
struct sk_buff_head *dst);
static inline struct tls_msg *tls_msg(struct sk_buff *skb)
{
struct sk_skb_cb *scb = (struct sk_skb_cb *)skb->cb;
@@ -132,6 +137,11 @@ static inline struct tls_msg *tls_msg(struct sk_buff *skb)
return &scb->tls;
}
static inline struct sk_buff *tls_strp_msg(struct tls_sw_context_rx *ctx)
{
return ctx->recv_pkt;
}
#ifdef CONFIG_TLS_DEVICE
int tls_device_init(void);
void tls_device_cleanup(void);
@@ -140,8 +150,7 @@ void tls_device_free_resources_tx(struct sock *sk);
int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx);
void tls_device_offload_cleanup_rx(struct sock *sk);
void tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq);
int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx,
struct sk_buff *skb, struct strp_msg *rxm);
int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx);
#else
static inline int tls_device_init(void) { return 0; }
static inline void tls_device_cleanup(void) {}
@@ -165,8 +174,7 @@ static inline void
tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq) {}
static inline int
tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx,
struct sk_buff *skb, struct strp_msg *rxm)
tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx)
{
return 0;
}

net/tls/tls_device.c

@@ -889,14 +889,19 @@ static void tls_device_core_ctrl_rx_resync(struct tls_context *tls_ctx,
}
}
static int tls_device_reencrypt(struct sock *sk, struct sk_buff *skb)
static int
tls_device_reencrypt(struct sock *sk, struct tls_sw_context_rx *sw_ctx)
{
struct strp_msg *rxm = strp_msg(skb);
int err = 0, offset = rxm->offset, copy, nsg, data_len, pos;
struct sk_buff *skb_iter, *unused;
int err = 0, offset, copy, nsg, data_len, pos;
struct sk_buff *skb, *skb_iter, *unused;
struct scatterlist sg[1];
struct strp_msg *rxm;
char *orig_buf, *buf;
skb = tls_strp_msg(sw_ctx);
rxm = strp_msg(skb);
offset = rxm->offset;
orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE +
TLS_CIPHER_AES_GCM_128_IV_SIZE, sk->sk_allocation);
if (!orig_buf)
@@ -919,7 +924,7 @@ static int tls_device_reencrypt(struct sock *sk, struct sk_buff *skb)
goto free_buf;
/* We are interested only in the decrypted data not the auth */
err = decrypt_skb(sk, skb, sg);
err = decrypt_skb(sk, sg);
if (err != -EBADMSG)
goto free_buf;
else
@@ -974,10 +979,12 @@ free_buf:
return err;
}
int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx,
struct sk_buff *skb, struct strp_msg *rxm)
int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx)
{
struct tls_offload_context_rx *ctx = tls_offload_ctx_rx(tls_ctx);
struct tls_sw_context_rx *sw_ctx = tls_sw_ctx_rx(tls_ctx);
struct sk_buff *skb = tls_strp_msg(sw_ctx);
struct strp_msg *rxm = strp_msg(skb);
int is_decrypted = skb->decrypted;
int is_encrypted = !is_decrypted;
struct sk_buff *skb_iter;
@@ -1000,7 +1007,7 @@ int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx,
* likely have initial fragments decrypted, and final ones not
* decrypted. We need to reencrypt that single SKB.
*/
return tls_device_reencrypt(sk, skb);
return tls_device_reencrypt(sk, sw_ctx);
}
/* Return immediately if the record is either entirely plaintext or
@@ -1017,7 +1024,7 @@ int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx,
}
ctx->resync_nh_reset = 1;
return tls_device_reencrypt(sk, skb);
return tls_device_reencrypt(sk, sw_ctx);
}
static void tls_device_attach(struct tls_context *ctx, struct sock *sk,

net/tls/tls_strp.c (new file, 17 lines)

@@ -0,0 +1,17 @@
// SPDX-License-Identifier: GPL-2.0-only
#include <linux/skbuff.h>
#include "tls.h"
int tls_strp_msg_hold(struct sock *sk, struct sk_buff *skb,
struct sk_buff_head *dst)
{
struct sk_buff *clone;
clone = skb_clone(skb, sk->sk_allocation);
if (!clone)
return -ENOMEM;
__skb_queue_tail(dst, clone);
return 0;
}
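tls_strp_msg_hold() pins the strparser message so its data stays valid while an async decrypt is in flight. A condensed, editorial view of its lifecycle, stitched together from tls_decrypt_sg() and the recv_end: handling shown below:

	/* submit path (tls_decrypt_sg): keep the record alive across the
	 * async AEAD request; on clone failure park the clear skb instead,
	 * so the completion path can still reclaim it
	 */
	if (unlikely(darg->async)) {
		err = tls_strp_msg_hold(sk, skb, &ctx->async_hold);
		if (err)
			__skb_queue_tail(&ctx->async_hold, darg->skb);
		return err;
	}

	/* completion path (recvmsg): once decrypt_pending drains to zero,
	 * release everything that was held
	 */
	__skb_queue_purge(&ctx->async_hold);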

net/tls/tls_sw.c

@@ -47,9 +47,13 @@
#include "tls.h"
struct tls_decrypt_arg {
struct_group(inargs,
bool zc;
bool async;
u8 tail;
);
struct sk_buff *skb;
};
struct tls_decrypt_ctx {
@@ -180,39 +184,22 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
struct scatterlist *sgin = aead_req->src;
struct tls_sw_context_rx *ctx;
struct tls_context *tls_ctx;
struct tls_prot_info *prot;
struct scatterlist *sg;
struct sk_buff *skb;
unsigned int pages;
struct sock *sk;
skb = (struct sk_buff *)req->data;
tls_ctx = tls_get_ctx(skb->sk);
sk = (struct sock *)req->data;
tls_ctx = tls_get_ctx(sk);
ctx = tls_sw_ctx_rx(tls_ctx);
prot = &tls_ctx->prot_info;
/* Propagate if there was an err */
if (err) {
if (err == -EBADMSG)
TLS_INC_STATS(sock_net(skb->sk),
LINUX_MIB_TLSDECRYPTERROR);
TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
ctx->async_wait.err = err;
tls_err_abort(skb->sk, err);
} else {
struct strp_msg *rxm = strp_msg(skb);
/* No TLS 1.3 support with async crypto */
WARN_ON(prot->tail_size);
rxm->offset += prot->prepend_size;
rxm->full_len -= prot->overhead_size;
tls_err_abort(sk, err);
}
/* After using skb->sk to propagate sk through crypto async callback
* we need to NULL it again.
*/
skb->sk = NULL;
/* Free the destination pages if skb was not decrypted inplace */
if (sgout != sgin) {
/* Skip the first S/G entry as it points to AAD */
@@ -232,7 +219,6 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
}
static int tls_do_decryption(struct sock *sk,
struct sk_buff *skb,
struct scatterlist *sgin,
struct scatterlist *sgout,
char *iv_recv,
@@ -252,16 +238,9 @@ static int tls_do_decryption(struct sock *sk,
(u8 *)iv_recv);
if (darg->async) {
/* Using skb->sk to push sk through to crypto async callback
* handler. This allows propagating errors up to the socket
* if needed. It _must_ be cleared in the async handler
* before consume_skb is called. We _know_ skb->sk is NULL
* because it is a clone from strparser.
*/
skb->sk = sk;
aead_request_set_callback(aead_req,
CRYPTO_TFM_REQ_MAY_BACKLOG,
tls_decrypt_done, skb);
tls_decrypt_done, sk);
atomic_inc(&ctx->decrypt_pending);
} else {
aead_request_set_callback(aead_req,
@@ -1404,51 +1383,90 @@ out:
return rc;
}
static struct sk_buff *
tls_alloc_clrtxt_skb(struct sock *sk, struct sk_buff *skb,
unsigned int full_len)
{
struct strp_msg *clr_rxm;
struct sk_buff *clr_skb;
int err;
clr_skb = alloc_skb_with_frags(0, full_len, TLS_PAGE_ORDER,
&err, sk->sk_allocation);
if (!clr_skb)
return NULL;
skb_copy_header(clr_skb, skb);
clr_skb->len = full_len;
clr_skb->data_len = full_len;
clr_rxm = strp_msg(clr_skb);
clr_rxm->offset = 0;
return clr_skb;
}
/* Decrypt handlers
*
* tls_decrypt_sg() and tls_decrypt_device() are decrypt handlers.
* They must transform the darg in/out argument as follows:
* | Input | Output
* -------------------------------------------------------------------
* zc | Zero-copy decrypt allowed | Zero-copy performed
* async | Async decrypt allowed | Async crypto used / in progress
* skb | * | Output skb
*/
/* This function decrypts the input skb into either out_iov or in out_sg
* or in skb buffers itself. The input parameter 'zc' indicates if
* or in skb buffers itself. The input parameter 'darg->zc' indicates if
* zero-copy mode needs to be tried or not. With zero-copy mode, either
* out_iov or out_sg must be non-NULL. In case both out_iov and out_sg are
* NULL, then the decryption happens inside skb buffers itself, i.e.
* zero-copy gets disabled and 'zc' is updated.
* zero-copy gets disabled and 'darg->zc' is updated.
*/
static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
struct iov_iter *out_iov,
struct scatterlist *out_sg,
struct tls_decrypt_arg *darg)
static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov,
struct scatterlist *out_sg,
struct tls_decrypt_arg *darg)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct tls_prot_info *prot = &tls_ctx->prot_info;
int n_sgin, n_sgout, aead_size, err, pages = 0;
struct strp_msg *rxm = strp_msg(skb);
struct tls_msg *tlm = tls_msg(skb);
struct sk_buff *skb = tls_strp_msg(ctx);
const struct strp_msg *rxm = strp_msg(skb);
const struct tls_msg *tlm = tls_msg(skb);
struct aead_request *aead_req;
struct sk_buff *unused;
struct scatterlist *sgin = NULL;
struct scatterlist *sgout = NULL;
const int data_len = rxm->full_len - prot->overhead_size;
int tail_pages = !!prot->tail_size;
struct tls_decrypt_ctx *dctx;
struct sk_buff *clear_skb;
int iv_offset = 0;
u8 *mem;
n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size,
rxm->full_len - prot->prepend_size);
if (n_sgin < 1)
return n_sgin ?: -EBADMSG;
if (darg->zc && (out_iov || out_sg)) {
clear_skb = NULL;
if (out_iov)
n_sgout = 1 + tail_pages +
iov_iter_npages_cap(out_iov, INT_MAX, data_len);
else
n_sgout = sg_nents(out_sg);
n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size,
rxm->full_len - prot->prepend_size);
} else {
n_sgout = 0;
darg->zc = false;
n_sgin = skb_cow_data(skb, 0, &unused);
}
if (n_sgin < 1)
return -EBADMSG;
clear_skb = tls_alloc_clrtxt_skb(sk, skb, rxm->full_len);
if (!clear_skb)
return -ENOMEM;
n_sgout = 1 + skb_shinfo(clear_skb)->nr_frags;
}
/* Increment to accommodate AAD */
n_sgin = n_sgin + 1;
@@ -1460,8 +1478,10 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
mem = kmalloc(aead_size + struct_size(dctx, sg, n_sgin + n_sgout),
sk->sk_allocation);
if (!mem)
return -ENOMEM;
if (!mem) {
err = -ENOMEM;
goto exit_free_skb;
}
/* Segment the allocated memory */
aead_req = (struct aead_request *)mem;
@@ -1510,88 +1530,107 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
if (err < 0)
goto exit_free;
if (n_sgout) {
if (out_iov) {
sg_init_table(sgout, n_sgout);
sg_set_buf(&sgout[0], dctx->aad, prot->aad_size);
if (clear_skb) {
sg_init_table(sgout, n_sgout);
sg_set_buf(&sgout[0], dctx->aad, prot->aad_size);
err = tls_setup_from_iter(out_iov, data_len,
&pages, &sgout[1],
(n_sgout - 1 - tail_pages));
if (err < 0)
goto fallback_to_reg_recv;
err = skb_to_sgvec(clear_skb, &sgout[1], prot->prepend_size,
data_len + prot->tail_size);
if (err < 0)
goto exit_free;
} else if (out_iov) {
sg_init_table(sgout, n_sgout);
sg_set_buf(&sgout[0], dctx->aad, prot->aad_size);
if (prot->tail_size) {
sg_unmark_end(&sgout[pages]);
sg_set_buf(&sgout[pages + 1], &dctx->tail,
prot->tail_size);
sg_mark_end(&sgout[pages + 1]);
}
} else if (out_sg) {
memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
} else {
goto fallback_to_reg_recv;
err = tls_setup_from_iter(out_iov, data_len, &pages, &sgout[1],
(n_sgout - 1 - tail_pages));
if (err < 0)
goto exit_free_pages;
if (prot->tail_size) {
sg_unmark_end(&sgout[pages]);
sg_set_buf(&sgout[pages + 1], &dctx->tail,
prot->tail_size);
sg_mark_end(&sgout[pages + 1]);
}
} else {
fallback_to_reg_recv:
sgout = sgin;
pages = 0;
darg->zc = false;
} else if (out_sg) {
memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
}
/* Prepare and submit AEAD request */
err = tls_do_decryption(sk, skb, sgin, sgout, dctx->iv,
err = tls_do_decryption(sk, sgin, sgout, dctx->iv,
data_len + prot->tail_size, aead_req, darg);
if (darg->async)
return 0;
if (err)
goto exit_free_pages;
darg->skb = clear_skb ?: tls_strp_msg(ctx);
clear_skb = NULL;
if (unlikely(darg->async)) {
err = tls_strp_msg_hold(sk, skb, &ctx->async_hold);
if (err)
__skb_queue_tail(&ctx->async_hold, darg->skb);
return err;
}
if (prot->tail_size)
darg->tail = dctx->tail;
exit_free_pages:
/* Release the pages in case iov was mapped to pages */
for (; pages > 0; pages--)
put_page(sg_page(&sgout[pages]));
exit_free:
kfree(mem);
exit_free_skb:
consume_skb(clear_skb);
return err;
}
static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
struct iov_iter *dest,
struct tls_decrypt_arg *darg)
static int
tls_decrypt_device(struct sock *sk, struct tls_context *tls_ctx,
struct tls_decrypt_arg *darg)
{
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
int err;
if (tls_ctx->rx_conf != TLS_HW)
return 0;
err = tls_device_decrypted(sk, tls_ctx);
if (err <= 0)
return err;
darg->zc = false;
darg->async = false;
darg->skb = tls_strp_msg(ctx);
ctx->recv_pkt = NULL;
return 1;
}
static int tls_rx_one_record(struct sock *sk, struct iov_iter *dest,
struct tls_decrypt_arg *darg)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct tls_prot_info *prot = &tls_ctx->prot_info;
struct strp_msg *rxm = strp_msg(skb);
struct tls_msg *tlm = tls_msg(skb);
struct strp_msg *rxm;
int pad, err;
if (tlm->decrypted) {
darg->zc = false;
darg->async = false;
return 0;
}
err = tls_decrypt_device(sk, tls_ctx, darg);
if (err < 0)
return err;
if (err)
goto decrypt_done;
if (tls_ctx->rx_conf == TLS_HW) {
err = tls_device_decrypted(sk, tls_ctx, skb, rxm);
if (err < 0)
return err;
if (err > 0) {
tlm->decrypted = 1;
darg->zc = false;
darg->async = false;
goto decrypt_done;
}
}
err = decrypt_internal(sk, skb, dest, NULL, darg);
err = tls_decrypt_sg(sk, dest, NULL, darg);
if (err < 0) {
if (err == -EBADMSG)
TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
return err;
}
if (darg->async)
goto decrypt_next;
goto decrypt_done;
/* If opportunistic TLS 1.3 ZC failed retry without ZC */
if (unlikely(darg->zc && prot->version == TLS_1_3_VERSION &&
darg->tail != TLS_RECORD_TYPE_DATA)) {
@@ -1599,30 +1638,33 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
if (!darg->tail)
TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXNOPADVIOL);
TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTRETRY);
return decrypt_skb_update(sk, skb, dest, darg);
return tls_rx_one_record(sk, dest, darg);
}
decrypt_done:
pad = tls_padding_length(prot, skb, darg);
if (pad < 0)
return pad;
if (darg->skb == ctx->recv_pkt)
ctx->recv_pkt = NULL;
pad = tls_padding_length(prot, darg->skb, darg);
if (pad < 0) {
consume_skb(darg->skb);
return pad;
}
rxm = strp_msg(darg->skb);
rxm->full_len -= pad;
rxm->offset += prot->prepend_size;
rxm->full_len -= prot->overhead_size;
tlm->decrypted = 1;
decrypt_next:
tls_advance_record_sn(sk, prot, &tls_ctx->rx);
return 0;
}
int decrypt_skb(struct sock *sk, struct sk_buff *skb,
struct scatterlist *sgout)
int decrypt_skb(struct sock *sk, struct scatterlist *sgout)
{
struct tls_decrypt_arg darg = { .zc = true, };
return decrypt_internal(sk, skb, NULL, sgout, &darg);
return tls_decrypt_sg(sk, NULL, sgout, &darg);
}
static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm,
@@ -1648,6 +1690,13 @@ static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm,
return 1;
}
static void tls_rx_rec_done(struct tls_sw_context_rx *ctx)
{
consume_skb(ctx->recv_pkt);
ctx->recv_pkt = NULL;
__strp_unpause(&ctx->strp);
}
/* This function traverses the rx_list in the tls receive context to copy the
* decrypted records into the buffer provided by the caller when zero copy is
* not true. Further, the records are removed from the rx_list if it is not a peek
@@ -1658,7 +1707,6 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
u8 *control,
size_t skip,
size_t len,
bool zc,
bool is_peek)
{
struct sk_buff *skb = skb_peek(&ctx->rx_list);
@@ -1692,12 +1740,10 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
if (err <= 0)
goto out;
if (!zc || (rxm->full_len - skip) > len) {
err = skb_copy_datagram_msg(skb, rxm->offset + skip,
msg, chunk);
if (err < 0)
goto out;
}
err = skb_copy_datagram_msg(skb, rxm->offset + skip,
msg, chunk);
if (err < 0)
goto out;
len = len - chunk;
copied = copied + chunk;
@@ -1753,6 +1799,51 @@ tls_read_flush_backlog(struct sock *sk, struct tls_prot_info *prot,
sk_flush_backlog(sk);
}
static long tls_rx_reader_lock(struct sock *sk, struct tls_sw_context_rx *ctx,
bool nonblock)
{
long timeo;
lock_sock(sk);
timeo = sock_rcvtimeo(sk, nonblock);
while (unlikely(ctx->reader_present)) {
DEFINE_WAIT_FUNC(wait, woken_wake_function);
ctx->reader_contended = 1;
add_wait_queue(&ctx->wq, &wait);
sk_wait_event(sk, &timeo,
!READ_ONCE(ctx->reader_present), &wait);
remove_wait_queue(&ctx->wq, &wait);
if (!timeo)
return -EAGAIN;
if (signal_pending(current))
return sock_intr_errno(timeo);
}
WRITE_ONCE(ctx->reader_present, 1);
return timeo;
}
static void tls_rx_reader_unlock(struct sock *sk, struct tls_sw_context_rx *ctx)
{
if (unlikely(ctx->reader_contended)) {
if (wq_has_sleeper(&ctx->wq))
wake_up(&ctx->wq);
else
ctx->reader_contended = 0;
WARN_ON_ONCE(!ctx->reader_present);
}
WRITE_ONCE(ctx->reader_present, 0);
release_sock(sk);
}
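Both Rx entry points bracket their work with this pair. A minimal caller sketch follows; example_rx_read() is hypothetical, but the convention matches the tls_sw_recvmsg() and tls_sw_splice_read() call sites below:

	static int example_rx_read(struct sock *sk, struct tls_sw_context_rx *ctx,
				   bool nonblock)
	{
		long timeo;

		timeo = tls_rx_reader_lock(sk, ctx, nonblock);
		if (timeo < 0)
			return timeo;	/* -EAGAIN or a signal-derived error */

		/* single-reader section: ctx->rx_list and ctx->recv_pkt can
		 * be used without another reader slipping in
		 */

		tls_rx_reader_unlock(sk, ctx);
		return 0;
	}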
int tls_sw_recvmsg(struct sock *sk,
struct msghdr *msg,
size_t len,
@@ -1762,9 +1853,9 @@ int tls_sw_recvmsg(struct sock *sk,
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct tls_prot_info *prot = &tls_ctx->prot_info;
ssize_t decrypted = 0, async_copy_bytes = 0;
struct sk_psock *psock;
unsigned char control = 0;
ssize_t decrypted = 0;
size_t flushed_at = 0;
struct strp_msg *rxm;
struct tls_msg *tlm;
@@ -1782,7 +1873,9 @@ int tls_sw_recvmsg(struct sock *sk,
return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
psock = sk_psock_get(sk);
lock_sock(sk);
timeo = tls_rx_reader_lock(sk, ctx, flags & MSG_DONTWAIT);
if (timeo < 0)
return timeo;
bpf_strp_enabled = sk_psock_strp_enabled(psock);
/* If crypto failed the connection is broken */
@@ -1791,7 +1884,7 @@ int tls_sw_recvmsg(struct sock *sk,
goto end;
/* Process pending decrypted records. It must be non-zero-copy */
err = process_rx_list(ctx, msg, &control, 0, len, false, is_peek);
err = process_rx_list(ctx, msg, &control, 0, len, is_peek);
if (err < 0)
goto end;
@@ -1801,13 +1894,12 @@ int tls_sw_recvmsg(struct sock *sk,
target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
len = len - copied;
timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
zc_capable = !bpf_strp_enabled && !is_kvec && !is_peek &&
ctx->zc_capable;
decrypted = 0;
while (len && (decrypted + copied < target || ctx->recv_pkt)) {
struct tls_decrypt_arg darg = {};
struct tls_decrypt_arg darg;
int to_decrypt, chunk;
err = tls_rx_rec_wait(sk, psock, flags & MSG_DONTWAIT, timeo);
@@ -1815,15 +1907,19 @@
if (psock) {
chunk = sk_msg_recvmsg(sk, psock, msg, len,
flags);
if (chunk > 0)
goto leave_on_list;
if (chunk > 0) {
decrypted += chunk;
len -= chunk;
continue;
}
}
goto recv_end;
}
skb = ctx->recv_pkt;
rxm = strp_msg(skb);
tlm = tls_msg(skb);
memset(&darg.inargs, 0, sizeof(darg.inargs));
rxm = strp_msg(ctx->recv_pkt);
tlm = tls_msg(ctx->recv_pkt);
to_decrypt = rxm->full_len - prot->overhead_size;
@@ -1837,12 +1933,16 @@ int tls_sw_recvmsg(struct sock *sk,
else
darg.async = false;
err = decrypt_skb_update(sk, skb, &msg->msg_iter, &darg);
err = tls_rx_one_record(sk, &msg->msg_iter, &darg);
if (err < 0) {
tls_err_abort(sk, -EBADMSG);
goto recv_end;
}
skb = darg.skb;
rxm = strp_msg(skb);
tlm = tls_msg(skb);
async |= darg.async;
/* If the type of records being processed is not known yet,
@@ -1853,34 +1953,36 @@ int tls_sw_recvmsg(struct sock *sk,
* For tls1.3, we disable async.
*/
err = tls_record_content_type(msg, tlm, &control);
if (err <= 0)
if (err <= 0) {
tls_rx_rec_done(ctx);
put_on_rx_list_err:
__skb_queue_tail(&ctx->rx_list, skb);
goto recv_end;
}
/* periodically flush backlog, and feed strparser */
tls_read_flush_backlog(sk, prot, len, to_decrypt,
decrypted + copied, &flushed_at);
ctx->recv_pkt = NULL;
__strp_unpause(&ctx->strp);
__skb_queue_tail(&ctx->rx_list, skb);
if (async) {
/* TLS 1.2-only, to_decrypt must be text length */
chunk = min_t(int, to_decrypt, len);
leave_on_list:
decrypted += chunk;
len -= chunk;
continue;
}
/* TLS 1.3 may have updated the length by more than overhead */
chunk = rxm->full_len;
tls_rx_rec_done(ctx);
if (!darg.zc) {
bool partially_consumed = chunk > len;
if (async) {
/* TLS 1.2-only, to_decrypt must be text len */
chunk = min_t(int, to_decrypt, len);
async_copy_bytes += chunk;
put_on_rx_list:
decrypted += chunk;
len -= chunk;
__skb_queue_tail(&ctx->rx_list, skb);
continue;
}
if (bpf_strp_enabled) {
/* BPF may try to queue the skb */
__skb_unlink(skb, &ctx->rx_list);
err = sk_psock_tls_strp_read(psock, skb);
if (err != __SK_PASS) {
rxm->offset = rxm->offset + rxm->full_len;
@@ -1889,7 +1991,6 @@ leave_on_list:
consume_skb(skb);
continue;
}
__skb_queue_tail(&ctx->rx_list, skb);
}
if (partially_consumed)
@@ -1898,22 +1999,21 @@ leave_on_list:
err = skb_copy_datagram_msg(skb, rxm->offset,
msg, chunk);
if (err < 0)
goto recv_end;
goto put_on_rx_list_err;
if (is_peek)
goto leave_on_list;
goto put_on_rx_list;
if (partially_consumed) {
rxm->offset += chunk;
rxm->full_len -= chunk;
goto leave_on_list;
goto put_on_rx_list;
}
}
decrypted += chunk;
len -= chunk;
__skb_unlink(skb, &ctx->rx_list);
consume_skb(skb);
/* Return full control message to userspace before trying
@@ -1933,30 +2033,32 @@ recv_end:
reinit_completion(&ctx->async_wait.completion);
pending = atomic_read(&ctx->decrypt_pending);
spin_unlock_bh(&ctx->decrypt_compl_lock);
if (pending) {
ret = 0;
if (pending)
ret = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
if (ret) {
if (err >= 0 || err == -EINPROGRESS)
err = ret;
decrypted = 0;
goto end;
}
__skb_queue_purge(&ctx->async_hold);
if (ret) {
if (err >= 0 || err == -EINPROGRESS)
err = ret;
decrypted = 0;
goto end;
}
/* Drain records from the rx_list & copy if required */
if (is_peek || is_kvec)
err = process_rx_list(ctx, msg, &control, copied,
decrypted, false, is_peek);
decrypted, is_peek);
else
err = process_rx_list(ctx, msg, &control, 0,
decrypted, true, is_peek);
async_copy_bytes, is_peek);
decrypted = max(err, 0);
}
copied += decrypted;
end:
release_sock(sk);
tls_rx_reader_unlock(sk, ctx);
if (psock)
sk_psock_put(sk, psock);
return copied ? : err;
@@ -1973,33 +2075,34 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
struct tls_msg *tlm;
struct sk_buff *skb;
ssize_t copied = 0;
bool from_queue;
int err = 0;
long timeo;
int chunk;
lock_sock(sk);
timeo = tls_rx_reader_lock(sk, ctx, flags & SPLICE_F_NONBLOCK);
if (timeo < 0)
return timeo;
timeo = sock_rcvtimeo(sk, flags & SPLICE_F_NONBLOCK);
from_queue = !skb_queue_empty(&ctx->rx_list);
if (from_queue) {
if (!skb_queue_empty(&ctx->rx_list)) {
skb = __skb_dequeue(&ctx->rx_list);
} else {
struct tls_decrypt_arg darg = {};
struct tls_decrypt_arg darg;
err = tls_rx_rec_wait(sk, NULL, flags & SPLICE_F_NONBLOCK,
timeo);
if (err <= 0)
goto splice_read_end;
skb = ctx->recv_pkt;
memset(&darg.inargs, 0, sizeof(darg.inargs));
err = decrypt_skb_update(sk, skb, NULL, &darg);
err = tls_rx_one_record(sk, NULL, &darg);
if (err < 0) {
tls_err_abort(sk, -EBADMSG);
goto splice_read_end;
}
tls_rx_rec_done(ctx);
skb = darg.skb;
}
rxm = strp_msg(skb);
@@ -2008,29 +2111,29 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
/* splice does not support reading control messages */
if (tlm->control != TLS_RECORD_TYPE_DATA) {
err = -EINVAL;
goto splice_read_end;
goto splice_requeue;
}
chunk = min_t(unsigned int, rxm->full_len, len);
copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
if (copied < 0)
goto splice_read_end;
goto splice_requeue;
if (!from_queue) {
ctx->recv_pkt = NULL;
__strp_unpause(&ctx->strp);
}
if (chunk < rxm->full_len) {
__skb_queue_head(&ctx->rx_list, skb);
rxm->offset += len;
rxm->full_len -= len;
} else {
consume_skb(skb);
goto splice_requeue;
}
consume_skb(skb);
splice_read_end:
release_sock(sk);
tls_rx_reader_unlock(sk, ctx);
return copied ? : err;
splice_requeue:
__skb_queue_head(&ctx->rx_list, skb);
goto splice_read_end;
}
bool tls_sw_sock_is_readable(struct sock *sk)
@@ -2076,7 +2179,6 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
if (ret < 0)
goto read_failure;
tlm->decrypted = 0;
tlm->control = header[0];
data_len = ((header[4] & 0xFF) | (header[3] << 8));
@@ -2371,9 +2473,11 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
} else {
crypto_init_wait(&sw_ctx_rx->async_wait);
spin_lock_init(&sw_ctx_rx->decrypt_compl_lock);
init_waitqueue_head(&sw_ctx_rx->wq);
crypto_info = &ctx->crypto_recv.info;
cctx = &ctx->rx;
skb_queue_head_init(&sw_ctx_rx->rx_list);
skb_queue_head_init(&sw_ctx_rx->async_hold);
aead = &sw_ctx_rx->aead_recv;
}
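Taken together, the receive fast path after this series looks roughly as follows. This is a condensed editorial sketch of the tls_sw_recvmsg() loop above, not verbatim kernel code; the key change is that the cleartext now travels in darg.skb (possibly a freshly allocated clear skb) instead of being produced in place in the strparser skb:

	while (len && (decrypted + copied < target || ctx->recv_pkt)) {
		struct tls_decrypt_arg darg;

		memset(&darg.inargs, 0, sizeof(darg.inargs));

		err = tls_rx_rec_wait(sk, psock, flags & MSG_DONTWAIT, timeo);
		if (err <= 0)
			break;				/* no record ready */

		err = tls_rx_one_record(sk, &msg->msg_iter, &darg);
		if (err < 0)
			break;				/* decrypt failed */

		/* darg.skb holds the cleartext record: copy it out, or park
		 * it on ctx->rx_list for the peek/async/partial-read cases
		 */
	}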