forked from mindspore-Ecosystem/mindspore
!15744 Add encryption for ckpt file
From: @liu_luobin Reviewed-by: @baochong,@chongbao Signed-off-by:
This commit is contained in:
commit
33966f0cd5
|
@ -5,10 +5,14 @@ else()
|
||||||
set(REQ_URL "https://github.com/openssl/openssl/archive/refs/tags/OpenSSL_1_1_1k.tar.gz")
|
set(REQ_URL "https://github.com/openssl/openssl/archive/refs/tags/OpenSSL_1_1_1k.tar.gz")
|
||||||
set(MD5 "bdd51a68ad74618dd2519da8e0bcc759")
|
set(MD5 "bdd51a68ad74618dd2519da8e0bcc759")
|
||||||
endif()
|
endif()
|
||||||
mindspore_add_pkg(openssl
|
if(${CMAKE_SYSTEM_NAME} MATCHES "Linux")
|
||||||
VER 1.1.0
|
mindspore_add_pkg(openssl
|
||||||
LIBS ssl crypto
|
VER 1.1.0
|
||||||
URL ${REQ_URL}
|
LIBS ssl crypto
|
||||||
MD5 ${MD5}
|
URL ${REQ_URL}
|
||||||
CONFIGURE_COMMAND ./config no-zlib no-shared)
|
MD5 ${MD5}
|
||||||
include_directories(${openssl_INC})
|
CONFIGURE_COMMAND ./config no-zlib no-shared)
|
||||||
|
include_directories(${openssl_INC})
|
||||||
|
add_library(mindspore::ssl ALIAS openssl::ssl)
|
||||||
|
add_library(mindspore::crypto ALIAS openssl::crypto)
|
||||||
|
endif()
|
||||||
|
|
|
@ -226,6 +226,7 @@ set(SUB_COMP
|
||||||
pipeline/jit
|
pipeline/jit
|
||||||
pipeline/pynative
|
pipeline/pynative
|
||||||
common debug pybind_api utils vm profiler ps
|
common debug pybind_api utils vm profiler ps
|
||||||
|
crypto
|
||||||
)
|
)
|
||||||
|
|
||||||
foreach(_comp ${SUB_COMP})
|
foreach(_comp ${SUB_COMP})
|
||||||
|
|
|
@ -0,0 +1,6 @@
|
||||||
|
file(GLOB_RECURSE _CRYPTO_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||||
|
add_library(_mindspore_crypto_obj OBJECT ${_CRYPTO_SRC_FILES})
|
||||||
|
|
||||||
|
if(${CMAKE_SYSTEM_NAME} MATCHES "Linux")
|
||||||
|
target_link_libraries(_mindspore_crypto_obj mindspore::crypto)
|
||||||
|
endif()
|
|
@ -0,0 +1,347 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "crypto/crypto.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace crypto {
|
||||||
|
int64_t Min(int64_t a, int64_t b) { return a < b ? a : b; }
|
||||||
|
|
||||||
|
Byte *intToByte(const int32_t &n) {
|
||||||
|
Byte *byte = new Byte[4];
|
||||||
|
memset(byte, 0, sizeof(Byte) * 4);
|
||||||
|
byte[0] = (Byte)(0xFF & n);
|
||||||
|
byte[1] = (Byte)((0xFF00 & n) >> 8);
|
||||||
|
byte[2] = (Byte)((0xFF0000 & n) >> 16);
|
||||||
|
byte[3] = (Byte)((0xFF000000 & n) >> 24);
|
||||||
|
return byte;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t ByteToint(const Byte *byteArray) {
|
||||||
|
int32_t res = byteArray[0] & 0xFF;
|
||||||
|
res |= ((byteArray[1] << 8) & 0xFF00);
|
||||||
|
res |= ((byteArray[2] << 16) & 0xFF0000);
|
||||||
|
res += ((byteArray[3] << 24) & 0xFF000000);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IsCipherFile(std::string file_path) {
|
||||||
|
char *int_buf = new char[4];
|
||||||
|
int flag = 0;
|
||||||
|
std::ifstream fid(file_path, std::ios::in | std::ios::binary);
|
||||||
|
if (!fid) {
|
||||||
|
MS_LOG(ERROR) << "Open file failed";
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
fid.read(int_buf, sizeof(int32_t));
|
||||||
|
fid.close();
|
||||||
|
flag = ByteToint(reinterpret_cast<Byte *>(int_buf));
|
||||||
|
delete[] int_buf;
|
||||||
|
return flag == MAGIC_NUM;
|
||||||
|
}
|
||||||
|
#if defined(_WIN32)
|
||||||
|
Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, Byte *key, const int32_t key_len,
|
||||||
|
const std::string &enc_mode) {
|
||||||
|
MS_EXCEPTION(NotSupportError) << "Unsupported feature in Windows platform.";
|
||||||
|
}
|
||||||
|
|
||||||
|
Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *key, const int32_t key_len,
|
||||||
|
const std::string &dec_mode) {
|
||||||
|
MS_EXCEPTION(NotSupportError) << "Unsupported feature in Windows platform.";
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
|
||||||
|
bool ParseEncryptData(const Byte *encrypt_data, const int32_t encrypt_len, Byte **iv, int32_t *iv_len,
|
||||||
|
Byte **cipher_data, int32_t *cipher_len) {
|
||||||
|
// Encrypt data is organized in order to iv_len, iv, cipher_len, cipher_data
|
||||||
|
Byte buf[4];
|
||||||
|
memcpy(buf, encrypt_data, 4);
|
||||||
|
*iv_len = ByteToint(buf);
|
||||||
|
memcpy(buf, encrypt_data + *iv_len + 4, 4);
|
||||||
|
*cipher_len = ByteToint(buf);
|
||||||
|
if (*iv_len <= 0 || *cipher_len <= 0 || *iv_len + *cipher_len + 8 != encrypt_len) {
|
||||||
|
MS_LOG(ERROR) << "Failed to parse encrypt data.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
*iv = new Byte[*iv_len];
|
||||||
|
memcpy(*iv, encrypt_data + 4, *iv_len);
|
||||||
|
*cipher_data = new Byte[*cipher_len];
|
||||||
|
memcpy(*cipher_data, encrypt_data + *iv_len + 8, *cipher_len);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ParseMode(std::string mode, std::string *alg_mode, std::string *work_mode) {
|
||||||
|
std::smatch results;
|
||||||
|
std::regex re("([A-Z]{3})-([A-Z]{3})");
|
||||||
|
if (!std::regex_match(mode.c_str(), re)) {
|
||||||
|
MS_LOG(ERROR) << "Mode " << mode << " is invalid.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
std::regex_search(mode, results, re);
|
||||||
|
*alg_mode = results[1];
|
||||||
|
*work_mode = results[2];
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
EVP_CIPHER_CTX *GetEVP_CIPHER_CTX(const std::string &work_mode, const Byte *key, const int32_t key_len, const Byte *iv,
|
||||||
|
int flag) {
|
||||||
|
int ret = 0;
|
||||||
|
EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
|
||||||
|
if (work_mode != "GCM" && work_mode != "CBC") {
|
||||||
|
MS_LOG(ERROR) << "Work mode " << work_mode << " is invalid.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
const EVP_CIPHER *(*funcPtr)() = nullptr;
|
||||||
|
if (work_mode == "GCM") {
|
||||||
|
switch (key_len) {
|
||||||
|
case 16:
|
||||||
|
funcPtr = EVP_aes_128_gcm;
|
||||||
|
break;
|
||||||
|
case 24:
|
||||||
|
funcPtr = EVP_aes_192_gcm;
|
||||||
|
break;
|
||||||
|
case 32:
|
||||||
|
funcPtr = EVP_aes_256_gcm;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
MS_EXCEPTION(ValueError) << "The key length must be 16, 24 or 32, but got key length is " << key_len << ".";
|
||||||
|
}
|
||||||
|
} else if (work_mode == "CBC") {
|
||||||
|
switch (key_len) {
|
||||||
|
case 16:
|
||||||
|
funcPtr = EVP_aes_128_cbc;
|
||||||
|
break;
|
||||||
|
case 24:
|
||||||
|
funcPtr = EVP_aes_192_cbc;
|
||||||
|
break;
|
||||||
|
case 32:
|
||||||
|
funcPtr = EVP_aes_256_cbc;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
MS_EXCEPTION(ValueError) << "The key length must be 16, 24 or 32, but got key length is " << key_len << ".";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (flag == 0) {
|
||||||
|
ret = EVP_EncryptInit_ex(ctx, funcPtr(), NULL, key, iv);
|
||||||
|
} else if (flag == 1) {
|
||||||
|
ret = EVP_DecryptInit_ex(ctx, funcPtr(), NULL, key, iv);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ret != 1) {
|
||||||
|
MS_LOG(ERROR) << "EVP_EncryptInit_ex failed";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (work_mode == "CBC") EVP_CIPHER_CTX_set_padding(ctx, 1);
|
||||||
|
return ctx;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool _BlockEncrypt(Byte *encrypt_data, int64_t *encrypt_data_len, Byte *plain_data, const int64_t plain_len, Byte *key,
|
||||||
|
const int32_t key_len, const std::string &enc_mode) {
|
||||||
|
// Encrypted according to enc_key and enc_mode, the format of the returned encrypted data block is "total length +
|
||||||
|
// iv length + iv + plain text length + cipher text length + cipher text"
|
||||||
|
int32_t cipher_len = 0; // cipher length
|
||||||
|
|
||||||
|
int32_t iv_len = AES_BLOCK_SIZE;
|
||||||
|
Byte *iv = new Byte[iv_len];
|
||||||
|
RAND_bytes(iv, sizeof(Byte) * iv_len);
|
||||||
|
|
||||||
|
Byte *iv_cpy = new Byte[16];
|
||||||
|
memcpy(iv_cpy, iv, 16);
|
||||||
|
|
||||||
|
// set the encryption length
|
||||||
|
int32_t ret = 0;
|
||||||
|
int32_t flen = 0;
|
||||||
|
std::string alg_mode;
|
||||||
|
std::string work_mode;
|
||||||
|
if (!ParseMode(enc_mode, &alg_mode, &work_mode)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ctx = GetEVP_CIPHER_CTX(work_mode, key, key_len, iv, 0);
|
||||||
|
if (ctx == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
Byte *cipher_data;
|
||||||
|
cipher_data = new Byte[plain_len + 16];
|
||||||
|
ret = EVP_EncryptUpdate(ctx, cipher_data, &cipher_len, plain_data, plain_len);
|
||||||
|
if (ret != 1) {
|
||||||
|
MS_LOG(ERROR) << "EVP_EncryptUpdate failed";
|
||||||
|
delete[] cipher_data;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (work_mode == "CBC") {
|
||||||
|
EVP_EncryptFinal_ex(ctx, cipher_data + cipher_len, &flen);
|
||||||
|
cipher_len += flen;
|
||||||
|
}
|
||||||
|
EVP_CIPHER_CTX_free(ctx);
|
||||||
|
|
||||||
|
int64_t cur = 0;
|
||||||
|
*encrypt_data_len = sizeof(int32_t) * 2 + iv_len + cipher_len; // 按iv长度、iv、明文长度、密文长度、密文进行拼接
|
||||||
|
|
||||||
|
memcpy(encrypt_data + cur, intToByte(*encrypt_data_len), 4);
|
||||||
|
cur += 4;
|
||||||
|
memcpy(encrypt_data + cur, intToByte(iv_len), 4);
|
||||||
|
cur += 4;
|
||||||
|
memcpy(encrypt_data + cur, iv_cpy, iv_len);
|
||||||
|
cur += iv_len;
|
||||||
|
memcpy(encrypt_data + cur, intToByte(cipher_len), 4);
|
||||||
|
cur += 4;
|
||||||
|
memcpy(encrypt_data + cur, cipher_data, cipher_len);
|
||||||
|
*encrypt_data_len += 4;
|
||||||
|
|
||||||
|
delete[] cipher_data;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool _BlockDecrypt(Byte **plain_data, int32_t *plain_len, Byte *encrypt_data, const int64_t encrypt_len, Byte *key,
|
||||||
|
const int32_t key_len, const std::string &dec_mode) {
|
||||||
|
// Decrypt according to dec_key and dec_mode, the format of the encrypted data block is "iv length + iv +
|
||||||
|
// plain text data length + cipher text data length + cipher text data"
|
||||||
|
std::string alg_mode;
|
||||||
|
std::string work_mode;
|
||||||
|
|
||||||
|
if (!ParseMode(dec_mode, &alg_mode, &work_mode)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析加密数据
|
||||||
|
int32_t iv_len = 0;
|
||||||
|
int32_t cipher_len = 0;
|
||||||
|
Byte *iv = NULL;
|
||||||
|
Byte *cipher_data = NULL;
|
||||||
|
|
||||||
|
if (!ParseEncryptData(encrypt_data, encrypt_len, &iv, &iv_len, &cipher_data, &cipher_len)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
*plain_data = new Byte[cipher_len + 16];
|
||||||
|
if (*plain_data == NULL) {
|
||||||
|
MS_LOG(ERROR) << "Unable to allocate memory for decrypt_string.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解密密文
|
||||||
|
int ret = 0;
|
||||||
|
int mlen = 0;
|
||||||
|
|
||||||
|
auto ctx = GetEVP_CIPHER_CTX(work_mode, key, key_len, iv, 1);
|
||||||
|
if (ctx == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
ret = EVP_DecryptUpdate(ctx, *plain_data, plain_len, cipher_data, cipher_len);
|
||||||
|
if (ret != 1) {
|
||||||
|
MS_LOG(ERROR) << "EVP_DecryptUpdate failed";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (work_mode == "CBC") {
|
||||||
|
ret = EVP_DecryptFinal_ex(ctx, *plain_data + *plain_len, &mlen);
|
||||||
|
if (ret != 1) {
|
||||||
|
MS_LOG(ERROR) << "EVP_DecryptFinal_ex failed";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
*plain_len += mlen;
|
||||||
|
}
|
||||||
|
delete[] iv;
|
||||||
|
delete[] cipher_data;
|
||||||
|
EVP_CIPHER_CTX_free(ctx);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, Byte *key, const int32_t key_len,
|
||||||
|
const std::string &enc_mode) {
|
||||||
|
int64_t cur_pos = 0;
|
||||||
|
int64_t block_enc_len = 0;
|
||||||
|
int64_t encrypt_buf_len = plain_len + (plain_len / MAX_BLOCK_SIZE + 1) * 100;
|
||||||
|
Byte *encrypt_data = new Byte[encrypt_buf_len];
|
||||||
|
Byte *block_buf = new Byte[MAX_BLOCK_SIZE];
|
||||||
|
Byte *block_enc_buf = new Byte[MAX_BLOCK_SIZE + 100];
|
||||||
|
|
||||||
|
*encrypt_len = 0;
|
||||||
|
while (cur_pos < plain_len) {
|
||||||
|
int64_t cur_block_size = Min(MAX_BLOCK_SIZE, plain_len - cur_pos);
|
||||||
|
memcpy(block_buf, plain_data + cur_pos, cur_block_size);
|
||||||
|
|
||||||
|
if (!_BlockEncrypt(block_enc_buf, &block_enc_len, block_buf, cur_block_size, key, key_len, enc_mode)) {
|
||||||
|
delete[] block_buf;
|
||||||
|
delete[] block_enc_buf;
|
||||||
|
delete[] encrypt_data;
|
||||||
|
MS_EXCEPTION(ValueError) << "Failed to encrypt data, please check if enc_key or enc_mode is valid.";
|
||||||
|
}
|
||||||
|
memcpy(encrypt_data + *encrypt_len, intToByte(MAGIC_NUM), sizeof(int32_t));
|
||||||
|
*encrypt_len += sizeof(int32_t);
|
||||||
|
memcpy(encrypt_data + *encrypt_len, block_enc_buf, block_enc_len);
|
||||||
|
*encrypt_len += block_enc_len;
|
||||||
|
cur_pos += cur_block_size;
|
||||||
|
}
|
||||||
|
delete[] block_buf;
|
||||||
|
delete[] block_enc_buf;
|
||||||
|
return encrypt_data;
|
||||||
|
}
|
||||||
|
|
||||||
|
Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *key, const int32_t key_len,
|
||||||
|
const std::string &dec_mode) {
|
||||||
|
Byte *decrypt_data = nullptr;
|
||||||
|
char *block_buf = new char[MAX_BLOCK_SIZE * 2];
|
||||||
|
char *int_buf = new char[4];
|
||||||
|
// Byte *decrypt_block_buf = new Byte[100];
|
||||||
|
Byte *decrypt_block_buf = nullptr;
|
||||||
|
int32_t decrypt_block_len;
|
||||||
|
|
||||||
|
std::ifstream fid(encrypt_data_path, std::ios::in | std::ios::binary);
|
||||||
|
if (!fid) {
|
||||||
|
MS_LOG(ERROR) << "Open file failed";
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
fid.seekg(0, std::ios_base::end);
|
||||||
|
int64_t file_size = fid.tellg();
|
||||||
|
fid.clear();
|
||||||
|
fid.seekg(0);
|
||||||
|
decrypt_data = new Byte[file_size];
|
||||||
|
|
||||||
|
*decrypt_len = 0;
|
||||||
|
while (fid.tellg() < file_size) {
|
||||||
|
fid.read(int_buf, sizeof(int32_t));
|
||||||
|
int cipher_flag = ByteToint(reinterpret_cast<Byte *>(int_buf));
|
||||||
|
if (cipher_flag != MAGIC_NUM) {
|
||||||
|
MS_EXCEPTION(ValueError) << "File \"" << encrypt_data_path
|
||||||
|
<< "\"is not an encrypted file and cannot be decrypted";
|
||||||
|
}
|
||||||
|
fid.read(int_buf, sizeof(int32_t));
|
||||||
|
|
||||||
|
int64_t block_size = ByteToint(reinterpret_cast<Byte *>(int_buf));
|
||||||
|
fid.read(block_buf, sizeof(char) * block_size);
|
||||||
|
if (!(_BlockDecrypt(&decrypt_block_buf, &decrypt_block_len, reinterpret_cast<Byte *>(block_buf), block_size, key,
|
||||||
|
key_len, dec_mode))) {
|
||||||
|
delete[] block_buf;
|
||||||
|
delete[] int_buf;
|
||||||
|
delete[] decrypt_data;
|
||||||
|
MS_EXCEPTION(ValueError) << "Failed to decrypt data, please check if dec_key or dec_mode is valid";
|
||||||
|
}
|
||||||
|
memcpy(decrypt_data, decrypt_block_buf, decrypt_block_len);
|
||||||
|
*decrypt_len += decrypt_block_len;
|
||||||
|
}
|
||||||
|
fid.close();
|
||||||
|
delete[] block_buf;
|
||||||
|
delete[] int_buf;
|
||||||
|
return decrypt_data;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
} // namespace crypto
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,46 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_CRYPTO_CRYPTO_H
|
||||||
|
#define MINDSPORE_CCSRC_CRYPTO_CRYPTO_H
|
||||||
|
|
||||||
|
#if not defined(_WIN32)
|
||||||
|
#include <openssl/aes.h>
|
||||||
|
#include <openssl/evp.h>
|
||||||
|
#include <openssl/rand.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <fstream>
|
||||||
|
#include <string>
|
||||||
|
#include <regex>
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
|
typedef unsigned char Byte;
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace crypto {
|
||||||
|
const int MAX_BLOCK_SIZE = 512 * 1024 * 1024; // Maximum ciphertext segment 512MB
|
||||||
|
const unsigned int MAGIC_NUM = 0x7F3A5ED8; // Magic number
|
||||||
|
|
||||||
|
Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, Byte *key, const int32_t key_len,
|
||||||
|
const std::string &enc_mode);
|
||||||
|
Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *key, const int32_t key_len,
|
||||||
|
const std::string &dec_mode);
|
||||||
|
bool IsCipherFile(const std::string file_path);
|
||||||
|
} // namespace crypto
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif
|
|
@ -0,0 +1,37 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "crypto/crypto_pybind.h"
|
||||||
|
namespace mindspore {
|
||||||
|
namespace crypto {
|
||||||
|
py::bytes PyEncrypt(char *plain_data, const int64_t plain_len, char *key, const int32_t key_len, std::string enc_mode) {
|
||||||
|
int64_t encrypt_len;
|
||||||
|
char *encrypt_data;
|
||||||
|
encrypt_data = reinterpret_cast<char *>(Encrypt(&encrypt_len, reinterpret_cast<Byte *>(plain_data), plain_len,
|
||||||
|
reinterpret_cast<Byte *>(key), key_len, enc_mode));
|
||||||
|
return py::bytes(encrypt_data, encrypt_len);
|
||||||
|
}
|
||||||
|
|
||||||
|
py::bytes PyDecrypt(std::string encrypt_data_path, char *key, const int32_t key_len, std::string dec_mode) {
|
||||||
|
int64_t decrypt_len;
|
||||||
|
char *decrypt_data;
|
||||||
|
decrypt_data = reinterpret_cast<char *>(
|
||||||
|
Decrypt(&decrypt_len, encrypt_data_path, reinterpret_cast<Byte *>(key), key_len, dec_mode));
|
||||||
|
return py::bytes(decrypt_data, decrypt_len);
|
||||||
|
}
|
||||||
|
bool PyIsCipherFile(std::string file_path) { return IsCipherFile(file_path); }
|
||||||
|
} // namespace crypto
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,32 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_CRYPTO_CRYPTO_PYBIND_H
|
||||||
|
#define MINDSPORE_CCSRC_CRYPTO_CRYPTO_PYBIND_H
|
||||||
|
#include "crypto/crypto.h"
|
||||||
|
#include <pybind11/pybind11.h>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace crypto {
|
||||||
|
py::bytes PyEncrypt(char *plain_data, const int64_t plain_len, char *key, const int32_t key_len, std::string enc_mode);
|
||||||
|
py::bytes PyDecrypt(std::string encrypt_data_path, char *key, const int32_t key_len, std::string dec_mode);
|
||||||
|
bool PyIsCipherFile(std::string file_path);
|
||||||
|
} // namespace crypto
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif
|
|
@ -28,6 +28,7 @@
|
||||||
#include "utils/mpi/mpi_config.h"
|
#include "utils/mpi/mpi_config.h"
|
||||||
#include "frontend/parallel/context.h"
|
#include "frontend/parallel/context.h"
|
||||||
#include "frontend/parallel/costmodel_context.h"
|
#include "frontend/parallel/costmodel_context.h"
|
||||||
|
#include "crypto/crypto_pybind.h"
|
||||||
#ifdef ENABLE_GPU_COLLECTIVE
|
#ifdef ENABLE_GPU_COLLECTIVE
|
||||||
#include "runtime/device/gpu/distribution/collective_init.h"
|
#include "runtime/device/gpu/distribution/collective_init.h"
|
||||||
#else
|
#else
|
||||||
|
@ -330,4 +331,8 @@ PYBIND11_MODULE(_c_expression, m) {
|
||||||
(void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy")
|
(void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy")
|
||||||
.def(py::init())
|
.def(py::init())
|
||||||
.def("get_all_ops_info", &OpInfoLoaderPy::GetAllOpsInfo, "get all ops info.");
|
.def("get_all_ops_info", &OpInfoLoaderPy::GetAllOpsInfo, "get all ops info.");
|
||||||
|
|
||||||
|
(void)m.def("_encrypt", &mindspore::crypto::PyEncrypt, "Encrypt the data.");
|
||||||
|
(void)m.def("_decrypt", &mindspore::crypto::PyDecrypt, "Decrypt the data.");
|
||||||
|
(void)m.def("_is_cipher_file", &mindspore::crypto::PyIsCipherFile, "Determine whether the file is encrypted");
|
||||||
}
|
}
|
||||||
|
|
|
@ -82,6 +82,10 @@ class CheckpointConfig:
|
||||||
async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False.
|
async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False.
|
||||||
saved_network (Cell): Network to be saved in checkpoint file. If the saved_network has no relation
|
saved_network (Cell): Network to be saved in checkpoint file. If the saved_network has no relation
|
||||||
with the network in training, the initial value of saved_network will be saved. Default: None.
|
with the network in training, the initial value of saved_network will be saved. Default: None.
|
||||||
|
enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is None, the encryption
|
||||||
|
is not required. Default: None.
|
||||||
|
enc_mode (str): This parameter is valid only when enc_key is not set to None. Specifies the encryption
|
||||||
|
mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If the input_param is None or 0.
|
ValueError: If the input_param is None or 0.
|
||||||
|
@ -126,7 +130,9 @@ class CheckpointConfig:
|
||||||
keep_checkpoint_per_n_minutes=0,
|
keep_checkpoint_per_n_minutes=0,
|
||||||
integrated_save=True,
|
integrated_save=True,
|
||||||
async_save=False,
|
async_save=False,
|
||||||
saved_network=None):
|
saved_network=None,
|
||||||
|
enc_key=None,
|
||||||
|
enc_mode='AES-GCM'):
|
||||||
|
|
||||||
if save_checkpoint_steps is not None:
|
if save_checkpoint_steps is not None:
|
||||||
save_checkpoint_steps = Validator.check_non_negative_int(save_checkpoint_steps)
|
save_checkpoint_steps = Validator.check_non_negative_int(save_checkpoint_steps)
|
||||||
|
@ -160,6 +166,8 @@ class CheckpointConfig:
|
||||||
self._integrated_save = Validator.check_bool(integrated_save)
|
self._integrated_save = Validator.check_bool(integrated_save)
|
||||||
self._async_save = Validator.check_bool(async_save)
|
self._async_save = Validator.check_bool(async_save)
|
||||||
self._saved_network = saved_network
|
self._saved_network = saved_network
|
||||||
|
self._enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
|
||||||
|
self._enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def save_checkpoint_steps(self):
|
def save_checkpoint_steps(self):
|
||||||
|
@ -196,6 +204,16 @@ class CheckpointConfig:
|
||||||
"""Get the value of _saved_network"""
|
"""Get the value of _saved_network"""
|
||||||
return self._saved_network
|
return self._saved_network
|
||||||
|
|
||||||
|
@property
|
||||||
|
def enc_key(self):
|
||||||
|
"""Get the value of _enc_key"""
|
||||||
|
return self._enc_key
|
||||||
|
|
||||||
|
@property
|
||||||
|
def enc_mode(self):
|
||||||
|
"""Get the value of _enc_mode"""
|
||||||
|
return self._enc_mode
|
||||||
|
|
||||||
def get_checkpoint_policy(self):
|
def get_checkpoint_policy(self):
|
||||||
"""Get the policy of checkpoint."""
|
"""Get the policy of checkpoint."""
|
||||||
checkpoint_policy = {'save_checkpoint_steps': self.save_checkpoint_steps,
|
checkpoint_policy = {'save_checkpoint_steps': self.save_checkpoint_steps,
|
||||||
|
@ -355,7 +373,7 @@ class ModelCheckpoint(Callback):
|
||||||
|
|
||||||
network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network
|
network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network
|
||||||
save_checkpoint(network, cur_file, self._config.integrated_save,
|
save_checkpoint(network, cur_file, self._config.integrated_save,
|
||||||
self._config.async_save)
|
self._config.async_save, self._config.enc_key, self._config.enc_mode)
|
||||||
|
|
||||||
self._latest_ckpt_file_name = cur_file
|
self._latest_ckpt_file_name = cur_file
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""Model and parameters serialization."""
|
"""Model and parameters serialization."""
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
import stat
|
import stat
|
||||||
import math
|
import math
|
||||||
|
@ -40,7 +41,7 @@ from mindspore._checkparam import check_input_data, Validator
|
||||||
from mindspore.compression.export import quant_export
|
from mindspore.compression.export import quant_export
|
||||||
from mindspore.parallel._tensor import _load_tensor
|
from mindspore.parallel._tensor import _load_tensor
|
||||||
from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices
|
from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices
|
||||||
from .._c_expression import load_mindir
|
from .._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file
|
||||||
|
|
||||||
|
|
||||||
tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
|
tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
|
||||||
|
@ -120,14 +121,19 @@ def _update_param(param, new_param):
|
||||||
param.set_data(type(param.data)(new_param.data))
|
param.set_data(type(param.data)(new_param.data))
|
||||||
|
|
||||||
|
|
||||||
def _exec_save(ckpt_file_name, data_list):
|
def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
|
||||||
"""Execute the process of saving checkpoint into file."""
|
"""Execute the process of saving checkpoint into file."""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
MAX_BLOCK_SIZE = 1024*1024*512
|
||||||
with _ckpt_mutex:
|
with _ckpt_mutex:
|
||||||
if os.path.exists(ckpt_file_name):
|
if os.path.exists(ckpt_file_name):
|
||||||
os.remove(ckpt_file_name)
|
os.remove(ckpt_file_name)
|
||||||
with open(ckpt_file_name, "ab") as f:
|
with open(ckpt_file_name, "ab") as f:
|
||||||
|
if enc_key is not None:
|
||||||
|
plain_data = bytes(0)
|
||||||
|
cipher_data = bytes(0)
|
||||||
|
|
||||||
for name, value in data_list.items():
|
for name, value in data_list.items():
|
||||||
data_size = value[2].nbytes / 1024
|
data_size = value[2].nbytes / 1024
|
||||||
if data_size > SLICE_SIZE:
|
if data_size > SLICE_SIZE:
|
||||||
|
@ -145,7 +151,19 @@ def _exec_save(ckpt_file_name, data_list):
|
||||||
param_tensor.tensor_type = value[1]
|
param_tensor.tensor_type = value[1]
|
||||||
param_tensor.tensor_content = param_slice.tobytes()
|
param_tensor.tensor_content = param_slice.tobytes()
|
||||||
|
|
||||||
f.write(checkpoint_list.SerializeToString())
|
if enc_key is None:
|
||||||
|
f.write(checkpoint_list.SerializeToString())
|
||||||
|
else:
|
||||||
|
plain_data += checkpoint_list.SerializeToString()
|
||||||
|
while len(plain_data) >= MAX_BLOCK_SIZE:
|
||||||
|
cipher_data += _encrypt(plain_data[0: MAX_BLOCK_SIZE], MAX_BLOCK_SIZE, enc_key,
|
||||||
|
len(enc_key), enc_mode)
|
||||||
|
plain_data = plain_data[MAX_BLOCK_SIZE:]
|
||||||
|
|
||||||
|
if enc_key is not None:
|
||||||
|
if plain_data:
|
||||||
|
cipher_data += _encrypt(plain_data, len(plain_data), enc_key, len(enc_key), enc_mode)
|
||||||
|
f.write(cipher_data)
|
||||||
|
|
||||||
os.chmod(ckpt_file_name, stat.S_IRUSR)
|
os.chmod(ckpt_file_name, stat.S_IRUSR)
|
||||||
|
|
||||||
|
@ -154,7 +172,7 @@ def _exec_save(ckpt_file_name, data_list):
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False):
|
def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False, enc_key=None, enc_mode="AES-GCM"):
|
||||||
"""
|
"""
|
||||||
Saves checkpoint info to a specified file.
|
Saves checkpoint info to a specified file.
|
||||||
|
|
||||||
|
@ -166,6 +184,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F
|
||||||
ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten.
|
ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten.
|
||||||
integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: True
|
integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: True
|
||||||
async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False
|
async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False
|
||||||
|
enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is None, the encryption
|
||||||
|
is not required. Default: None.
|
||||||
|
enc_mode (str): This parameter is valid only when enc_key is not set to None. Specifies the encryption
|
||||||
|
mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: If the parameter save_obj is not `nn.Cell` or list type. And if the parameter
|
TypeError: If the parameter save_obj is not `nn.Cell` or list type. And if the parameter
|
||||||
|
@ -176,6 +198,8 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F
|
||||||
raise TypeError("The parameter save_obj should be nn.Cell or list, but got {}".format(type(save_obj)))
|
raise TypeError("The parameter save_obj should be nn.Cell or list, but got {}".format(type(save_obj)))
|
||||||
integrated_save = Validator.check_bool(integrated_save)
|
integrated_save = Validator.check_bool(integrated_save)
|
||||||
async_save = Validator.check_bool(async_save)
|
async_save = Validator.check_bool(async_save)
|
||||||
|
enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
|
||||||
|
enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)
|
||||||
|
|
||||||
logger.info("Execute the process of saving checkpoint files.")
|
logger.info("Execute the process of saving checkpoint files.")
|
||||||
|
|
||||||
|
@ -218,10 +242,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F
|
||||||
data_list[key].append(data)
|
data_list[key].append(data)
|
||||||
|
|
||||||
if async_save:
|
if async_save:
|
||||||
thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list), name="asyn_save_ckpt")
|
thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list, enc_key, enc_mode), name="asyn_save_ckpt")
|
||||||
thr.start()
|
thr.start()
|
||||||
else:
|
else:
|
||||||
_exec_save(ckpt_file_name, data_list)
|
_exec_save(ckpt_file_name, data_list, enc_key, enc_mode)
|
||||||
|
|
||||||
logger.info("Saving checkpoint process is finished.")
|
logger.info("Saving checkpoint process is finished.")
|
||||||
|
|
||||||
|
@ -278,7 +302,7 @@ def load(file_name):
|
||||||
return graph
|
return graph
|
||||||
|
|
||||||
|
|
||||||
def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None):
|
def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None, dec_mode="AES-GCM"):
|
||||||
"""
|
"""
|
||||||
Loads checkpoint info from a specified file.
|
Loads checkpoint info from a specified file.
|
||||||
|
|
||||||
|
@ -289,6 +313,10 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
||||||
in the param_dict into net with the same suffix. Default: False
|
in the param_dict into net with the same suffix. Default: False
|
||||||
filter_prefix (Union[str, list[str], tuple[str]]): Parameters starting with the filter_prefix
|
filter_prefix (Union[str, list[str], tuple[str]]): Parameters starting with the filter_prefix
|
||||||
will not be loaded. Default: None.
|
will not be loaded. Default: None.
|
||||||
|
dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is None, the decryption
|
||||||
|
is not required. Default: None.
|
||||||
|
dec_mode (str): This parameter is valid only when dec_key is not set to None. Specifies the decryption
|
||||||
|
mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict, key is parameter name, value is a Parameter.
|
Dict, key is parameter name, value is a Parameter.
|
||||||
|
@ -303,15 +331,25 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
||||||
>>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1")
|
>>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1")
|
||||||
"""
|
"""
|
||||||
ckpt_file_name, filter_prefix = _check_checkpoint_param(ckpt_file_name, filter_prefix)
|
ckpt_file_name, filter_prefix = _check_checkpoint_param(ckpt_file_name, filter_prefix)
|
||||||
|
dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
|
||||||
|
dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
|
||||||
logger.info("Execute the process of loading checkpoint files.")
|
logger.info("Execute the process of loading checkpoint files.")
|
||||||
checkpoint_list = Checkpoint()
|
checkpoint_list = Checkpoint()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(ckpt_file_name, "rb") as f:
|
if dec_key is None:
|
||||||
pb_content = f.read()
|
with open(ckpt_file_name, "rb") as f:
|
||||||
|
pb_content = f.read()
|
||||||
|
else:
|
||||||
|
pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode)
|
||||||
checkpoint_list.ParseFromString(pb_content)
|
checkpoint_list.ParseFromString(pb_content)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", ckpt_file_name)
|
if _is_cipher_file(ckpt_file_name):
|
||||||
|
logger.error("Failed to read the checkpoint file `%s`. The file may be encrypted, please pass in the "
|
||||||
|
"dec_key.", ckpt_file_name)
|
||||||
|
else:
|
||||||
|
logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", \
|
||||||
|
ckpt_file_name)
|
||||||
raise ValueError(e.__str__())
|
raise ValueError(e.__str__())
|
||||||
|
|
||||||
parameter_dict = {}
|
parameter_dict = {}
|
||||||
|
@ -1082,7 +1120,7 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
|
||||||
return merged_parameter
|
return merged_parameter
|
||||||
|
|
||||||
|
|
||||||
def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None):
|
def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None, dec_key=None, dec_mode='AES-GCM'):
|
||||||
"""
|
"""
|
||||||
Load checkpoint into net for distributed predication.
|
Load checkpoint into net for distributed predication.
|
||||||
|
|
||||||
|
@ -1095,6 +1133,10 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
|
||||||
elements are [dev_matrix, tensor_map, param_split_shape, field]. If None,
|
elements are [dev_matrix, tensor_map, param_split_shape, field]. If None,
|
||||||
it means that the predication process just uses single device.
|
it means that the predication process just uses single device.
|
||||||
Default: None.
|
Default: None.
|
||||||
|
dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is None, the decryption
|
||||||
|
is not required. Default: None.
|
||||||
|
dec_mode (str): This parameter is valid only when dec_key is not set to None. Specifies the decryption
|
||||||
|
mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: The type of inputs do not match the requirements.
|
TypeError: The type of inputs do not match the requirements.
|
||||||
|
@ -1113,6 +1155,9 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
|
||||||
f"dev_matrix (list[int]), tensor_map (list[int]), "
|
f"dev_matrix (list[int]), tensor_map (list[int]), "
|
||||||
f"param_split_shape (list[int]) and field_size (zero).")
|
f"param_split_shape (list[int]) and field_size (zero).")
|
||||||
|
|
||||||
|
dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
|
||||||
|
dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
|
||||||
|
|
||||||
train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file")
|
train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file")
|
||||||
_train_strategy = build_searched_strategy(train_strategy_filename)
|
_train_strategy = build_searched_strategy(train_strategy_filename)
|
||||||
train_strategy = _convert_to_list(_train_strategy)
|
train_strategy = _convert_to_list(_train_strategy)
|
||||||
|
@ -1135,7 +1180,7 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
|
||||||
param_rank = rank_list[param.name][0]
|
param_rank = rank_list[param.name][0]
|
||||||
skip_merge_split = rank_list[param.name][1]
|
skip_merge_split = rank_list[param.name][1]
|
||||||
for rank in param_rank:
|
for rank in param_rank:
|
||||||
sliced_param = load_checkpoint(checkpoint_filenames[rank])[param.name]
|
sliced_param = load_checkpoint(checkpoint_filenames[rank], dec_key=dec_key, dec_mode=dec_mode)[param.name]
|
||||||
sliced_params.append(sliced_param)
|
sliced_params.append(sliced_param)
|
||||||
if skip_merge_split:
|
if skip_merge_split:
|
||||||
split_param = sliced_params[0]
|
split_param = sliced_params[0]
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""test callback function."""
|
"""test callback function."""
|
||||||
import os
|
import os
|
||||||
|
import platform
|
||||||
import stat
|
import stat
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
|
@ -246,6 +247,43 @@ def test_checkpoint_save_ckpt_seconds():
|
||||||
ckpt_cb2.step_end(run_context)
|
ckpt_cb2.step_end(run_context)
|
||||||
|
|
||||||
|
|
||||||
|
def test_checkpoint_save_ckpt_with_encryption():
|
||||||
|
"""Test checkpoint save ckpt with encryption."""
|
||||||
|
train_config = CheckpointConfig(
|
||||||
|
save_checkpoint_steps=16,
|
||||||
|
save_checkpoint_seconds=0,
|
||||||
|
keep_checkpoint_max=5,
|
||||||
|
keep_checkpoint_per_n_minutes=0,
|
||||||
|
enc_key=os.urandom(16),
|
||||||
|
enc_mode="AES-GCM")
|
||||||
|
ckpt_cb = ModelCheckpoint(config=train_config)
|
||||||
|
cb_params = _InternalCallbackParam()
|
||||||
|
net = Net()
|
||||||
|
loss = nn.SoftmaxCrossEntropyWithLogits()
|
||||||
|
optim = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||||
|
network_ = WithLossCell(net, loss)
|
||||||
|
_train_network = TrainOneStepCell(network_, optim)
|
||||||
|
cb_params.train_network = _train_network
|
||||||
|
cb_params.epoch_num = 10
|
||||||
|
cb_params.cur_epoch_num = 5
|
||||||
|
cb_params.cur_step_num = 160
|
||||||
|
cb_params.batch_num = 32
|
||||||
|
run_context = RunContext(cb_params)
|
||||||
|
ckpt_cb.begin(run_context)
|
||||||
|
ckpt_cb.step_end(run_context)
|
||||||
|
ckpt_cb2 = ModelCheckpoint(config=train_config)
|
||||||
|
cb_params.cur_epoch_num = 1
|
||||||
|
cb_params.cur_step_num = 15
|
||||||
|
|
||||||
|
if platform.system().lower() == "windows":
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
ckpt_cb2.begin(run_context)
|
||||||
|
ckpt_cb2.step_end(run_context)
|
||||||
|
else:
|
||||||
|
ckpt_cb2.begin(run_context)
|
||||||
|
ckpt_cb2.step_end(run_context)
|
||||||
|
|
||||||
|
|
||||||
def test_CallbackManager():
|
def test_CallbackManager():
|
||||||
"""TestCallbackManager."""
|
"""TestCallbackManager."""
|
||||||
ck_obj = ModelCheckpoint()
|
ck_obj = ModelCheckpoint()
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""ut for model serialize(save/load)"""
|
"""ut for model serialize(save/load)"""
|
||||||
import os
|
import os
|
||||||
|
import platform
|
||||||
import stat
|
import stat
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
@ -299,6 +300,30 @@ def test_load_checkpoint_empty_file():
|
||||||
load_checkpoint("empty.ckpt")
|
load_checkpoint("empty.ckpt")
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_and_load_checkpoint_for_network_with_encryption():
|
||||||
|
""" test save and checkpoint for network with encryption"""
|
||||||
|
net = Net()
|
||||||
|
loss = SoftmaxCrossEntropyWithLogits(sparse=True)
|
||||||
|
opt = Momentum(net.trainable_params(), 0.0, 0.9, 0.0001, 1024)
|
||||||
|
|
||||||
|
loss_net = WithLossCell(net, loss)
|
||||||
|
train_network = TrainOneStepCell(loss_net, opt)
|
||||||
|
key = os.urandom(16)
|
||||||
|
mode = "AES-GCM"
|
||||||
|
ckpt_path = "./encrypt_ckpt.ckpt"
|
||||||
|
if platform.system().lower() == "windows":
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
save_checkpoint(train_network, ckpt_file_name=ckpt_path, enc_key=key, enc_mode=mode)
|
||||||
|
param_dict = load_checkpoint(ckpt_path, dec_key=key, dec_mode="AES-GCM")
|
||||||
|
load_param_into_net(net, param_dict)
|
||||||
|
else:
|
||||||
|
save_checkpoint(train_network, ckpt_file_name=ckpt_path, enc_key=key, enc_mode=mode)
|
||||||
|
param_dict = load_checkpoint(ckpt_path, dec_key=key, dec_mode="AES-GCM")
|
||||||
|
load_param_into_net(net, param_dict)
|
||||||
|
if os.path.exists(ckpt_path):
|
||||||
|
os.remove(ckpt_path)
|
||||||
|
|
||||||
|
|
||||||
class MYNET(nn.Cell):
|
class MYNET(nn.Cell):
|
||||||
""" NET definition """
|
""" NET definition """
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue