diff --git a/drivers/crypto/padlock-aes.c b/drivers/crypto/padlock-aes.c index bf2917d197a0..856b3cc25583 100644 --- a/drivers/crypto/padlock-aes.c +++ b/drivers/crypto/padlock-aes.c @@ -15,6 +15,8 @@ #include #include #include +#include +#include #include #include #include "padlock.h" @@ -49,6 +51,8 @@ struct aes_ctx { u32 *D; }; +static DEFINE_PER_CPU(struct cword *, last_cword); + /* Tells whether the ACE is capable to generate the extended key for a given key_len. */ static inline int @@ -89,6 +93,7 @@ static int aes_set_key(struct crypto_tfm *tfm, const u8 *in_key, const __le32 *key = (const __le32 *)in_key; u32 *flags = &tfm->crt_flags; struct crypto_aes_ctx gen_aes; + int cpu; if (key_len % 8) { *flags |= CRYPTO_TFM_RES_BAD_KEY_LEN; @@ -118,7 +123,7 @@ static int aes_set_key(struct crypto_tfm *tfm, const u8 *in_key, /* Don't generate extended keys if the hardware can do it. */ if (aes_hw_extkey_available(key_len)) - return 0; + goto ok; ctx->D = ctx->d_data; ctx->cword.encrypt.keygen = 1; @@ -131,15 +136,30 @@ static int aes_set_key(struct crypto_tfm *tfm, const u8 *in_key, memcpy(ctx->E, gen_aes.key_enc, AES_MAX_KEYLENGTH); memcpy(ctx->D, gen_aes.key_dec, AES_MAX_KEYLENGTH); + +ok: + for_each_online_cpu(cpu) + if (&ctx->cword.encrypt == per_cpu(last_cword, cpu) || + &ctx->cword.decrypt == per_cpu(last_cword, cpu)) + per_cpu(last_cword, cpu) = NULL; + return 0; } /* ====== Encryption/decryption routines ====== */ /* These are the real call to PadLock. */ -static inline void padlock_reset_key(void) +static inline void padlock_reset_key(struct cword *cword) { - asm volatile ("pushfl; popfl"); + int cpu = raw_smp_processor_id(); + + if (cword != per_cpu(last_cword, cpu)) + asm volatile ("pushfl; popfl"); +} + +static inline void padlock_store_cword(struct cword *cword) +{ + per_cpu(last_cword, raw_smp_processor_id()) = cword; } /* @@ -149,7 +169,7 @@ static inline void padlock_reset_key(void) */ static inline void padlock_xcrypt(const u8 *input, u8 *output, void *key, - void *control_word) + struct cword *control_word) { asm volatile (".byte 0xf3,0x0f,0xa7,0xc8" /* rep xcryptecb */ : "+S"(input), "+D"(output) @@ -213,22 +233,24 @@ static void aes_encrypt(struct crypto_tfm *tfm, u8 *out, const u8 *in) { struct aes_ctx *ctx = aes_ctx(tfm); int ts_state; - padlock_reset_key(); + padlock_reset_key(&ctx->cword.encrypt); ts_state = irq_ts_save(); aes_crypt(in, out, ctx->E, &ctx->cword.encrypt); irq_ts_restore(ts_state); + padlock_store_cword(&ctx->cword.encrypt); } static void aes_decrypt(struct crypto_tfm *tfm, u8 *out, const u8 *in) { struct aes_ctx *ctx = aes_ctx(tfm); int ts_state; - padlock_reset_key(); + padlock_reset_key(&ctx->cword.encrypt); ts_state = irq_ts_save(); aes_crypt(in, out, ctx->D, &ctx->cword.decrypt); irq_ts_restore(ts_state); + padlock_store_cword(&ctx->cword.encrypt); } static struct crypto_alg aes_alg = { @@ -261,7 +283,7 @@ static int ecb_aes_encrypt(struct blkcipher_desc *desc, int err; int ts_state; - padlock_reset_key(); + padlock_reset_key(&ctx->cword.encrypt); blkcipher_walk_init(&walk, dst, src, nbytes); err = blkcipher_walk_virt(desc, &walk); @@ -276,6 +298,8 @@ static int ecb_aes_encrypt(struct blkcipher_desc *desc, } irq_ts_restore(ts_state); + padlock_store_cword(&ctx->cword.encrypt); + return err; } @@ -288,7 +312,7 @@ static int ecb_aes_decrypt(struct blkcipher_desc *desc, int err; int ts_state; - padlock_reset_key(); + padlock_reset_key(&ctx->cword.decrypt); blkcipher_walk_init(&walk, dst, src, nbytes); err = blkcipher_walk_virt(desc, &walk); @@ -302,6 +326,9 @@ static int ecb_aes_decrypt(struct blkcipher_desc *desc, err = blkcipher_walk_done(desc, &walk, nbytes); } irq_ts_restore(ts_state); + + padlock_store_cword(&ctx->cword.encrypt); + return err; } @@ -336,7 +363,7 @@ static int cbc_aes_encrypt(struct blkcipher_desc *desc, int err; int ts_state; - padlock_reset_key(); + padlock_reset_key(&ctx->cword.encrypt); blkcipher_walk_init(&walk, dst, src, nbytes); err = blkcipher_walk_virt(desc, &walk); @@ -353,6 +380,8 @@ static int cbc_aes_encrypt(struct blkcipher_desc *desc, } irq_ts_restore(ts_state); + padlock_store_cword(&ctx->cword.decrypt); + return err; } @@ -365,7 +394,7 @@ static int cbc_aes_decrypt(struct blkcipher_desc *desc, int err; int ts_state; - padlock_reset_key(); + padlock_reset_key(&ctx->cword.encrypt); blkcipher_walk_init(&walk, dst, src, nbytes); err = blkcipher_walk_virt(desc, &walk); @@ -380,6 +409,9 @@ static int cbc_aes_decrypt(struct blkcipher_desc *desc, } irq_ts_restore(ts_state); + + padlock_store_cword(&ctx->cword.encrypt); + return err; }