diff --git a/drivers/net/wireguard/netlink.c b/drivers/net/wireguard/netlink.c
index 0fdbd1c45977..bda26405497c 100644
--- a/drivers/net/wireguard/netlink.c
+++ b/drivers/net/wireguard/netlink.c
@@ -569,10 +569,8 @@ static int wg_set_device(struct sk_buff *skb, struct genl_info *info)
 							 private_key);
 		list_for_each_entry_safe(peer, temp, &wg->peer_list,
 					 peer_list) {
-			if (wg_noise_precompute_static_static(peer))
-				wg_noise_expire_current_peer_keypairs(peer);
-			else
-				wg_peer_remove(peer);
+			BUG_ON(!wg_noise_precompute_static_static(peer));
+			wg_noise_expire_current_peer_keypairs(peer);
 		}
 		wg_cookie_checker_precompute_device_keys(&wg->cookie_checker);
 		up_write(&wg->static_identity.lock);
diff --git a/drivers/net/wireguard/noise.c b/drivers/net/wireguard/noise.c
index d71c8db68a8c..919d9d866446 100644
--- a/drivers/net/wireguard/noise.c
+++ b/drivers/net/wireguard/noise.c
@@ -46,17 +46,21 @@ void __init wg_noise_init(void)
 /* Must hold peer->handshake.static_identity->lock */
 bool wg_noise_precompute_static_static(struct wg_peer *peer)
 {
-	bool ret = true;
+	bool ret;
 
 	down_write(&peer->handshake.lock);
-	if (peer->handshake.static_identity->has_identity)
+	if (peer->handshake.static_identity->has_identity) {
 		ret = curve25519(
 			peer->handshake.precomputed_static_static,
 			peer->handshake.static_identity->static_private,
 			peer->handshake.remote_static);
-	else
+	} else {
+		u8 empty[NOISE_PUBLIC_KEY_LEN] = { 0 };
+
+		ret = curve25519(empty, empty, peer->handshake.remote_static);
 		memset(peer->handshake.precomputed_static_static, 0,
 		       NOISE_PUBLIC_KEY_LEN);
+	}
 	up_write(&peer->handshake.lock);
 	return ret;
 }