forked from mindspore-Ecosystem/mindspore
!15996 Fix load failed while loading oversize checkpoint file with encyption
From: @liu_luobin Reviewed-by: @zh_qh,@pkuliuliu Signed-off-by: @zh_qh
This commit is contained in:
commit
1a7200901f
|
@ -66,20 +66,20 @@ Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *
|
|||
|
||||
bool ParseEncryptData(const Byte *encrypt_data, const int32_t encrypt_len, Byte **iv, int32_t *iv_len,
|
||||
Byte **cipher_data, int32_t *cipher_len) {
|
||||
// Encrypt data is organized in order to iv_len, iv, cipher_len, cipher_data
|
||||
// encrypt_data is organized in order to iv_len, iv, cipher_len, cipher_data
|
||||
Byte buf[4];
|
||||
memcpy(buf, encrypt_data, 4);
|
||||
memcpy_s(buf, 4, encrypt_data, 4);
|
||||
*iv_len = ByteToint(buf);
|
||||
memcpy(buf, encrypt_data + *iv_len + 4, 4);
|
||||
memcpy_s(buf, 4, encrypt_data + *iv_len + 4, 4);
|
||||
*cipher_len = ByteToint(buf);
|
||||
if (*iv_len <= 0 || *cipher_len <= 0 || *iv_len + *cipher_len + 8 != encrypt_len) {
|
||||
MS_LOG(ERROR) << "Failed to parse encrypt data.";
|
||||
return false;
|
||||
}
|
||||
*iv = new Byte[*iv_len];
|
||||
memcpy(*iv, encrypt_data + 4, *iv_len);
|
||||
memcpy_s(*iv, *iv_len, encrypt_data + 4, *iv_len);
|
||||
*cipher_data = new Byte[*cipher_len];
|
||||
memcpy(*cipher_data, encrypt_data + *iv_len + 8, *cipher_len);
|
||||
memcpy_s(*cipher_data, *cipher_len, encrypt_data + *iv_len + 8, *cipher_len);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -152,18 +152,15 @@ EVP_CIPHER_CTX *GetEVP_CIPHER_CTX(const std::string &work_mode, const Byte *key,
|
|||
|
||||
bool _BlockEncrypt(Byte *encrypt_data, int64_t *encrypt_data_len, Byte *plain_data, const int64_t plain_len, Byte *key,
|
||||
const int32_t key_len, const std::string &enc_mode) {
|
||||
// Encrypted according to enc_key and enc_mode, the format of the returned encrypted data block is "total length +
|
||||
// iv length + iv + plain text length + cipher text length + cipher text"
|
||||
int32_t cipher_len = 0; // cipher length
|
||||
int32_t cipher_len = 0;
|
||||
|
||||
int32_t iv_len = AES_BLOCK_SIZE;
|
||||
Byte *iv = new Byte[iv_len];
|
||||
RAND_bytes(iv, sizeof(Byte) * iv_len);
|
||||
|
||||
Byte *iv_cpy = new Byte[16];
|
||||
memcpy(iv_cpy, iv, 16);
|
||||
memcpy_s(iv_cpy, 16, iv, 16);
|
||||
|
||||
// set the encryption length
|
||||
int32_t ret = 0;
|
||||
int32_t flen = 0;
|
||||
std::string alg_mode;
|
||||
|
@ -193,7 +190,7 @@ bool _BlockEncrypt(Byte *encrypt_data, int64_t *encrypt_data_len, Byte *plain_da
|
|||
EVP_CIPHER_CTX_free(ctx);
|
||||
|
||||
int64_t cur = 0;
|
||||
*encrypt_data_len = sizeof(int32_t) * 2 + iv_len + cipher_len; // 按iv长度、iv、明文长度、密文长度、密文进行拼接
|
||||
*encrypt_data_len = sizeof(int32_t) * 2 + iv_len + cipher_len;
|
||||
|
||||
memcpy(encrypt_data + cur, intToByte(*encrypt_data_len), 4);
|
||||
cur += 4;
|
||||
|
@ -212,8 +209,6 @@ bool _BlockEncrypt(Byte *encrypt_data, int64_t *encrypt_data_len, Byte *plain_da
|
|||
|
||||
bool _BlockDecrypt(Byte **plain_data, int32_t *plain_len, Byte *encrypt_data, const int64_t encrypt_len, Byte *key,
|
||||
const int32_t key_len, const std::string &dec_mode) {
|
||||
// Decrypt according to dec_key and dec_mode, the format of the encrypted data block is "iv length + iv +
|
||||
// plain text data length + cipher text data length + cipher text data"
|
||||
std::string alg_mode;
|
||||
std::string work_mode;
|
||||
|
||||
|
@ -221,7 +216,6 @@ bool _BlockDecrypt(Byte **plain_data, int32_t *plain_len, Byte *encrypt_data, co
|
|||
return false;
|
||||
}
|
||||
|
||||
// 解析加密数据
|
||||
int32_t iv_len = 0;
|
||||
int32_t cipher_len = 0;
|
||||
Byte *iv = NULL;
|
||||
|
@ -236,7 +230,6 @@ bool _BlockDecrypt(Byte **plain_data, int32_t *plain_len, Byte *encrypt_data, co
|
|||
return false;
|
||||
}
|
||||
|
||||
// 解密密文
|
||||
int ret = 0;
|
||||
int mlen = 0;
|
||||
|
||||
|
@ -276,7 +269,7 @@ Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, B
|
|||
*encrypt_len = 0;
|
||||
while (cur_pos < plain_len) {
|
||||
int64_t cur_block_size = Min(MAX_BLOCK_SIZE, plain_len - cur_pos);
|
||||
memcpy(block_buf, plain_data + cur_pos, cur_block_size);
|
||||
memcpy_s(block_buf, MAX_BLOCK_SIZE, plain_data + cur_pos, cur_block_size);
|
||||
|
||||
if (!_BlockEncrypt(block_enc_buf, &block_enc_len, block_buf, cur_block_size, key, key_len, enc_mode)) {
|
||||
delete[] block_buf;
|
||||
|
@ -284,9 +277,9 @@ Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, B
|
|||
delete[] encrypt_data;
|
||||
MS_EXCEPTION(ValueError) << "Failed to encrypt data, please check if enc_key or enc_mode is valid.";
|
||||
}
|
||||
memcpy(encrypt_data + *encrypt_len, intToByte(MAGIC_NUM), sizeof(int32_t));
|
||||
memcpy_s(encrypt_data + *encrypt_len, encrypt_buf_len - *encrypt_len, intToByte(MAGIC_NUM), sizeof(int32_t));
|
||||
*encrypt_len += sizeof(int32_t);
|
||||
memcpy(encrypt_data + *encrypt_len, block_enc_buf, block_enc_len);
|
||||
memcpy_s(encrypt_data + *encrypt_len, encrypt_buf_len - *encrypt_len, block_enc_buf, block_enc_len);
|
||||
*encrypt_len += block_enc_len;
|
||||
cur_pos += cur_block_size;
|
||||
}
|
||||
|
@ -300,7 +293,6 @@ Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *
|
|||
Byte *decrypt_data = nullptr;
|
||||
char *block_buf = new char[MAX_BLOCK_SIZE * 2];
|
||||
char *int_buf = new char[4];
|
||||
// Byte *decrypt_block_buf = new Byte[100];
|
||||
Byte *decrypt_block_buf = nullptr;
|
||||
int32_t decrypt_block_len;
|
||||
|
||||
|
@ -325,7 +317,7 @@ Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *
|
|||
}
|
||||
fid.read(int_buf, sizeof(int32_t));
|
||||
|
||||
int64_t block_size = ByteToint(reinterpret_cast<Byte *>(int_buf));
|
||||
int32_t block_size = ByteToint(reinterpret_cast<Byte *>(int_buf));
|
||||
fid.read(block_buf, sizeof(char) * block_size);
|
||||
if (!(_BlockDecrypt(&decrypt_block_buf, &decrypt_block_len, reinterpret_cast<Byte *>(block_buf), block_size, key,
|
||||
key_len, dec_mode))) {
|
||||
|
@ -334,7 +326,7 @@ Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *
|
|||
delete[] decrypt_data;
|
||||
MS_EXCEPTION(ValueError) << "Failed to decrypt data, please check if dec_key or dec_mode is valid";
|
||||
}
|
||||
memcpy(decrypt_data, decrypt_block_buf, decrypt_block_len);
|
||||
memcpy_s(decrypt_data + *decrypt_len, file_size - *decrypt_len, decrypt_block_buf, decrypt_block_len);
|
||||
*decrypt_len += decrypt_block_len;
|
||||
}
|
||||
fid.close();
|
||||
|
|
|
@ -33,7 +33,7 @@ typedef unsigned char Byte;
|
|||
|
||||
namespace mindspore {
|
||||
namespace crypto {
|
||||
const int MAX_BLOCK_SIZE = 512 * 1024 * 1024; // Maximum ciphertext segment 512MB
|
||||
const int MAX_BLOCK_SIZE = 512 * 1024 * 1024; // Maximum ciphertext segment, units is Byte
|
||||
const unsigned int MAGIC_NUM = 0x7F3A5ED8; // Magic number
|
||||
|
||||
Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, Byte *key, const int32_t key_len,
|
||||
|
|
|
@ -168,10 +168,12 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
|
|||
f.write(checkpoint_list.SerializeToString())
|
||||
else:
|
||||
plain_data += checkpoint_list.SerializeToString()
|
||||
while len(plain_data) >= SLICE_SIZE * 1024:
|
||||
cipher_data += _encrypt(plain_data[0: SLICE_SIZE*1024], SLICE_SIZE*1024, enc_key,
|
||||
|
||||
max_block_size = SLICE_SIZE*1024
|
||||
while len(plain_data) >= max_block_size:
|
||||
cipher_data += _encrypt(plain_data[0: max_block_size], max_block_size, enc_key,
|
||||
len(enc_key), enc_mode)
|
||||
plain_data = plain_data[SLICE_SIZE*1024:]
|
||||
plain_data = plain_data[max_block_size:]
|
||||
|
||||
if enc_key is not None:
|
||||
if plain_data:
|
||||
|
|
Loading…
Reference in New Issue