diff --git a/mindspore/ccsrc/crypto/crypto.cc b/mindspore/ccsrc/crypto/crypto.cc index 748f73a6535..de66ff91d36 100644 --- a/mindspore/ccsrc/crypto/crypto.cc +++ b/mindspore/ccsrc/crypto/crypto.cc @@ -66,20 +66,20 @@ Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte * bool ParseEncryptData(const Byte *encrypt_data, const int32_t encrypt_len, Byte **iv, int32_t *iv_len, Byte **cipher_data, int32_t *cipher_len) { - // Encrypt data is organized in order to iv_len, iv, cipher_len, cipher_data + // encrypt_data is organized in order to iv_len, iv, cipher_len, cipher_data Byte buf[4]; - memcpy(buf, encrypt_data, 4); + memcpy_s(buf, 4, encrypt_data, 4); *iv_len = ByteToint(buf); - memcpy(buf, encrypt_data + *iv_len + 4, 4); + memcpy_s(buf, 4, encrypt_data + *iv_len + 4, 4); *cipher_len = ByteToint(buf); if (*iv_len <= 0 || *cipher_len <= 0 || *iv_len + *cipher_len + 8 != encrypt_len) { MS_LOG(ERROR) << "Failed to parse encrypt data."; return false; } *iv = new Byte[*iv_len]; - memcpy(*iv, encrypt_data + 4, *iv_len); + memcpy_s(*iv, *iv_len, encrypt_data + 4, *iv_len); *cipher_data = new Byte[*cipher_len]; - memcpy(*cipher_data, encrypt_data + *iv_len + 8, *cipher_len); + memcpy_s(*cipher_data, *cipher_len, encrypt_data + *iv_len + 8, *cipher_len); return true; } @@ -152,18 +152,15 @@ EVP_CIPHER_CTX *GetEVP_CIPHER_CTX(const std::string &work_mode, const Byte *key, bool _BlockEncrypt(Byte *encrypt_data, int64_t *encrypt_data_len, Byte *plain_data, const int64_t plain_len, Byte *key, const int32_t key_len, const std::string &enc_mode) { - // Encrypted according to enc_key and enc_mode, the format of the returned encrypted data block is "total length + - // iv length + iv + plain text length + cipher text length + cipher text" - int32_t cipher_len = 0; // cipher length + int32_t cipher_len = 0; int32_t iv_len = AES_BLOCK_SIZE; Byte *iv = new Byte[iv_len]; RAND_bytes(iv, sizeof(Byte) * iv_len); Byte *iv_cpy = new Byte[16]; - memcpy(iv_cpy, iv, 16); + memcpy_s(iv_cpy, 16, iv, 16); - // set the encryption length int32_t ret = 0; int32_t flen = 0; std::string alg_mode; @@ -193,7 +190,7 @@ bool _BlockEncrypt(Byte *encrypt_data, int64_t *encrypt_data_len, Byte *plain_da EVP_CIPHER_CTX_free(ctx); int64_t cur = 0; - *encrypt_data_len = sizeof(int32_t) * 2 + iv_len + cipher_len; // 按iv长度、iv、明文长度、密文长度、密文进行拼接 + *encrypt_data_len = sizeof(int32_t) * 2 + iv_len + cipher_len; memcpy(encrypt_data + cur, intToByte(*encrypt_data_len), 4); cur += 4; @@ -212,8 +209,6 @@ bool _BlockEncrypt(Byte *encrypt_data, int64_t *encrypt_data_len, Byte *plain_da bool _BlockDecrypt(Byte **plain_data, int32_t *plain_len, Byte *encrypt_data, const int64_t encrypt_len, Byte *key, const int32_t key_len, const std::string &dec_mode) { - // Decrypt according to dec_key and dec_mode, the format of the encrypted data block is "iv length + iv + - // plain text data length + cipher text data length + cipher text data" std::string alg_mode; std::string work_mode; @@ -221,7 +216,6 @@ bool _BlockDecrypt(Byte **plain_data, int32_t *plain_len, Byte *encrypt_data, co return false; } - // 解析加密数据 int32_t iv_len = 0; int32_t cipher_len = 0; Byte *iv = NULL; @@ -236,7 +230,6 @@ bool _BlockDecrypt(Byte **plain_data, int32_t *plain_len, Byte *encrypt_data, co return false; } - // 解密密文 int ret = 0; int mlen = 0; @@ -276,7 +269,7 @@ Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, B *encrypt_len = 0; while (cur_pos < plain_len) { int64_t cur_block_size = Min(MAX_BLOCK_SIZE, plain_len - cur_pos); - memcpy(block_buf, plain_data + cur_pos, cur_block_size); + memcpy_s(block_buf, MAX_BLOCK_SIZE, plain_data + cur_pos, cur_block_size); if (!_BlockEncrypt(block_enc_buf, &block_enc_len, block_buf, cur_block_size, key, key_len, enc_mode)) { delete[] block_buf; @@ -284,9 +277,9 @@ Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, B delete[] encrypt_data; MS_EXCEPTION(ValueError) << "Failed to encrypt data, please check if enc_key or enc_mode is valid."; } - memcpy(encrypt_data + *encrypt_len, intToByte(MAGIC_NUM), sizeof(int32_t)); + memcpy_s(encrypt_data + *encrypt_len, encrypt_buf_len - *encrypt_len, intToByte(MAGIC_NUM), sizeof(int32_t)); *encrypt_len += sizeof(int32_t); - memcpy(encrypt_data + *encrypt_len, block_enc_buf, block_enc_len); + memcpy_s(encrypt_data + *encrypt_len, encrypt_buf_len - *encrypt_len, block_enc_buf, block_enc_len); *encrypt_len += block_enc_len; cur_pos += cur_block_size; } @@ -300,7 +293,6 @@ Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte * Byte *decrypt_data = nullptr; char *block_buf = new char[MAX_BLOCK_SIZE * 2]; char *int_buf = new char[4]; - // Byte *decrypt_block_buf = new Byte[100]; Byte *decrypt_block_buf = nullptr; int32_t decrypt_block_len; @@ -325,7 +317,7 @@ Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte * } fid.read(int_buf, sizeof(int32_t)); - int64_t block_size = ByteToint(reinterpret_cast(int_buf)); + int32_t block_size = ByteToint(reinterpret_cast(int_buf)); fid.read(block_buf, sizeof(char) * block_size); if (!(_BlockDecrypt(&decrypt_block_buf, &decrypt_block_len, reinterpret_cast(block_buf), block_size, key, key_len, dec_mode))) { @@ -334,7 +326,7 @@ Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte * delete[] decrypt_data; MS_EXCEPTION(ValueError) << "Failed to decrypt data, please check if dec_key or dec_mode is valid"; } - memcpy(decrypt_data, decrypt_block_buf, decrypt_block_len); + memcpy_s(decrypt_data + *decrypt_len, file_size - *decrypt_len, decrypt_block_buf, decrypt_block_len); *decrypt_len += decrypt_block_len; } fid.close(); diff --git a/mindspore/ccsrc/crypto/crypto.h b/mindspore/ccsrc/crypto/crypto.h index edf7e8176b7..29d3b14f08e 100644 --- a/mindspore/ccsrc/crypto/crypto.h +++ b/mindspore/ccsrc/crypto/crypto.h @@ -33,7 +33,7 @@ typedef unsigned char Byte; namespace mindspore { namespace crypto { -const int MAX_BLOCK_SIZE = 512 * 1024 * 1024; // Maximum ciphertext segment 512MB +const int MAX_BLOCK_SIZE = 512 * 1024 * 1024; // Maximum ciphertext segment, units is Byte const unsigned int MAGIC_NUM = 0x7F3A5ED8; // Magic number Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, Byte *key, const int32_t key_len, diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index b4906171d5a..d069277335b 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -168,10 +168,12 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"): f.write(checkpoint_list.SerializeToString()) else: plain_data += checkpoint_list.SerializeToString() - while len(plain_data) >= SLICE_SIZE * 1024: - cipher_data += _encrypt(plain_data[0: SLICE_SIZE*1024], SLICE_SIZE*1024, enc_key, + + max_block_size = SLICE_SIZE*1024 + while len(plain_data) >= max_block_size: + cipher_data += _encrypt(plain_data[0: max_block_size], max_block_size, enc_key, len(enc_key), enc_mode) - plain_data = plain_data[SLICE_SIZE*1024:] + plain_data = plain_data[max_block_size:] if enc_key is not None: if plain_data: