!15996 Fix load failed while loading oversize checkpoint file with encyption

From: @liu_luobin
Reviewed-by: @zh_qh,@pkuliuliu
Signed-off-by: @zh_qh
This commit is contained in:
mindspore-ci-bot 2021-05-06 15:34:40 +08:00 committed by Gitee
commit 1a7200901f
3 changed files with 19 additions and 25 deletions

View File

@ -66,20 +66,20 @@ Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *
bool ParseEncryptData(const Byte *encrypt_data, const int32_t encrypt_len, Byte **iv, int32_t *iv_len,
Byte **cipher_data, int32_t *cipher_len) {
// Encrypt data is organized in order to iv_len, iv, cipher_len, cipher_data
// encrypt_data is organized in order to iv_len, iv, cipher_len, cipher_data
Byte buf[4];
memcpy(buf, encrypt_data, 4);
memcpy_s(buf, 4, encrypt_data, 4);
*iv_len = ByteToint(buf);
memcpy(buf, encrypt_data + *iv_len + 4, 4);
memcpy_s(buf, 4, encrypt_data + *iv_len + 4, 4);
*cipher_len = ByteToint(buf);
if (*iv_len <= 0 || *cipher_len <= 0 || *iv_len + *cipher_len + 8 != encrypt_len) {
MS_LOG(ERROR) << "Failed to parse encrypt data.";
return false;
}
*iv = new Byte[*iv_len];
memcpy(*iv, encrypt_data + 4, *iv_len);
memcpy_s(*iv, *iv_len, encrypt_data + 4, *iv_len);
*cipher_data = new Byte[*cipher_len];
memcpy(*cipher_data, encrypt_data + *iv_len + 8, *cipher_len);
memcpy_s(*cipher_data, *cipher_len, encrypt_data + *iv_len + 8, *cipher_len);
return true;
}
@ -152,18 +152,15 @@ EVP_CIPHER_CTX *GetEVP_CIPHER_CTX(const std::string &work_mode, const Byte *key,
bool _BlockEncrypt(Byte *encrypt_data, int64_t *encrypt_data_len, Byte *plain_data, const int64_t plain_len, Byte *key,
const int32_t key_len, const std::string &enc_mode) {
// Encrypted according to enc_key and enc_mode, the format of the returned encrypted data block is "total length +
// iv length + iv + plain text length + cipher text length + cipher text"
int32_t cipher_len = 0; // cipher length
int32_t cipher_len = 0;
int32_t iv_len = AES_BLOCK_SIZE;
Byte *iv = new Byte[iv_len];
RAND_bytes(iv, sizeof(Byte) * iv_len);
Byte *iv_cpy = new Byte[16];
memcpy(iv_cpy, iv, 16);
memcpy_s(iv_cpy, 16, iv, 16);
// set the encryption length
int32_t ret = 0;
int32_t flen = 0;
std::string alg_mode;
@ -193,7 +190,7 @@ bool _BlockEncrypt(Byte *encrypt_data, int64_t *encrypt_data_len, Byte *plain_da
EVP_CIPHER_CTX_free(ctx);
int64_t cur = 0;
*encrypt_data_len = sizeof(int32_t) * 2 + iv_len + cipher_len; // 按iv长度、iv、明文长度、密文长度、密文进行拼接
*encrypt_data_len = sizeof(int32_t) * 2 + iv_len + cipher_len;
memcpy(encrypt_data + cur, intToByte(*encrypt_data_len), 4);
cur += 4;
@ -212,8 +209,6 @@ bool _BlockEncrypt(Byte *encrypt_data, int64_t *encrypt_data_len, Byte *plain_da
bool _BlockDecrypt(Byte **plain_data, int32_t *plain_len, Byte *encrypt_data, const int64_t encrypt_len, Byte *key,
const int32_t key_len, const std::string &dec_mode) {
// Decrypt according to dec_key and dec_mode, the format of the encrypted data block is "iv length + iv +
// plain text data length + cipher text data length + cipher text data"
std::string alg_mode;
std::string work_mode;
@ -221,7 +216,6 @@ bool _BlockDecrypt(Byte **plain_data, int32_t *plain_len, Byte *encrypt_data, co
return false;
}
// 解析加密数据
int32_t iv_len = 0;
int32_t cipher_len = 0;
Byte *iv = NULL;
@ -236,7 +230,6 @@ bool _BlockDecrypt(Byte **plain_data, int32_t *plain_len, Byte *encrypt_data, co
return false;
}
// 解密密文
int ret = 0;
int mlen = 0;
@ -276,7 +269,7 @@ Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, B
*encrypt_len = 0;
while (cur_pos < plain_len) {
int64_t cur_block_size = Min(MAX_BLOCK_SIZE, plain_len - cur_pos);
memcpy(block_buf, plain_data + cur_pos, cur_block_size);
memcpy_s(block_buf, MAX_BLOCK_SIZE, plain_data + cur_pos, cur_block_size);
if (!_BlockEncrypt(block_enc_buf, &block_enc_len, block_buf, cur_block_size, key, key_len, enc_mode)) {
delete[] block_buf;
@ -284,9 +277,9 @@ Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, B
delete[] encrypt_data;
MS_EXCEPTION(ValueError) << "Failed to encrypt data, please check if enc_key or enc_mode is valid.";
}
memcpy(encrypt_data + *encrypt_len, intToByte(MAGIC_NUM), sizeof(int32_t));
memcpy_s(encrypt_data + *encrypt_len, encrypt_buf_len - *encrypt_len, intToByte(MAGIC_NUM), sizeof(int32_t));
*encrypt_len += sizeof(int32_t);
memcpy(encrypt_data + *encrypt_len, block_enc_buf, block_enc_len);
memcpy_s(encrypt_data + *encrypt_len, encrypt_buf_len - *encrypt_len, block_enc_buf, block_enc_len);
*encrypt_len += block_enc_len;
cur_pos += cur_block_size;
}
@ -300,7 +293,6 @@ Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *
Byte *decrypt_data = nullptr;
char *block_buf = new char[MAX_BLOCK_SIZE * 2];
char *int_buf = new char[4];
// Byte *decrypt_block_buf = new Byte[100];
Byte *decrypt_block_buf = nullptr;
int32_t decrypt_block_len;
@ -325,7 +317,7 @@ Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *
}
fid.read(int_buf, sizeof(int32_t));
int64_t block_size = ByteToint(reinterpret_cast<Byte *>(int_buf));
int32_t block_size = ByteToint(reinterpret_cast<Byte *>(int_buf));
fid.read(block_buf, sizeof(char) * block_size);
if (!(_BlockDecrypt(&decrypt_block_buf, &decrypt_block_len, reinterpret_cast<Byte *>(block_buf), block_size, key,
key_len, dec_mode))) {
@ -334,7 +326,7 @@ Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *
delete[] decrypt_data;
MS_EXCEPTION(ValueError) << "Failed to decrypt data, please check if dec_key or dec_mode is valid";
}
memcpy(decrypt_data, decrypt_block_buf, decrypt_block_len);
memcpy_s(decrypt_data + *decrypt_len, file_size - *decrypt_len, decrypt_block_buf, decrypt_block_len);
*decrypt_len += decrypt_block_len;
}
fid.close();

View File

@ -33,7 +33,7 @@ typedef unsigned char Byte;
namespace mindspore {
namespace crypto {
const int MAX_BLOCK_SIZE = 512 * 1024 * 1024; // Maximum ciphertext segment 512MB
const int MAX_BLOCK_SIZE = 512 * 1024 * 1024; // Maximum ciphertext segment, units is Byte
const unsigned int MAGIC_NUM = 0x7F3A5ED8; // Magic number
Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, Byte *key, const int32_t key_len,

View File

@ -168,10 +168,12 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
f.write(checkpoint_list.SerializeToString())
else:
plain_data += checkpoint_list.SerializeToString()
while len(plain_data) >= SLICE_SIZE * 1024:
cipher_data += _encrypt(plain_data[0: SLICE_SIZE*1024], SLICE_SIZE*1024, enc_key,
max_block_size = SLICE_SIZE*1024
while len(plain_data) >= max_block_size:
cipher_data += _encrypt(plain_data[0: max_block_size], max_block_size, enc_key,
len(enc_key), enc_mode)
plain_data = plain_data[SLICE_SIZE*1024:]
plain_data = plain_data[max_block_size:]
if enc_key is not None:
if plain_data: