diff --git a/include/net/tls.h b/include/net/tls.h index df630f5fc723..bf9eb4823933 100644 --- a/include/net/tls.h +++ b/include/net/tls.h @@ -641,6 +641,7 @@ int tls_sw_fallback_init(struct sock *sk, #ifdef CONFIG_TLS_DEVICE void tls_device_init(void); void tls_device_cleanup(void); +void tls_device_sk_destruct(struct sock *sk); int tls_set_device_offload(struct sock *sk, struct tls_context *ctx); void tls_device_free_resources_tx(struct sock *sk); int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx); @@ -649,6 +650,14 @@ void tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq); void tls_offload_tx_resync_request(struct sock *sk, u32 got_seq, u32 exp_seq); int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx, struct sk_buff *skb, struct strp_msg *rxm); + +static inline bool tls_is_sk_rx_device_offloaded(struct sock *sk) +{ + if (!sk_fullsock(sk) || + smp_load_acquire(&sk->sk_destruct) != tls_device_sk_destruct) + return false; + return tls_get_ctx(sk)->rx_conf == TLS_HW; +} #else static inline void tls_device_init(void) {} static inline void tls_device_cleanup(void) {} diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c index cd91ad812291..1ba5a92832bb 100644 --- a/net/tls/tls_device.c +++ b/net/tls/tls_device.c @@ -178,7 +178,7 @@ static void tls_icsk_clean_acked(struct sock *sk, u32 acked_seq) * socket and no in-flight SKBs associated with this * socket, so it is safe to free all the resources. */ -static void tls_device_sk_destruct(struct sock *sk) +void tls_device_sk_destruct(struct sock *sk) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx); @@ -196,6 +196,7 @@ static void tls_device_sk_destruct(struct sock *sk) if (refcount_dec_and_test(&tls_ctx->refcount)) tls_device_queue_ctx_destruction(tls_ctx); } +EXPORT_SYMBOL_GPL(tls_device_sk_destruct); void tls_device_free_resources_tx(struct sock *sk) { @@ -903,7 +904,7 @@ static void tls_device_attach(struct tls_context *ctx, struct sock *sk, spin_unlock_irq(&tls_device_lock); ctx->sk_destruct = sk->sk_destruct; - sk->sk_destruct = tls_device_sk_destruct; + smp_store_release(&sk->sk_destruct, tls_device_sk_destruct); } }