Add encryption for ckpt file

This commit is contained in:
liuluobin 2021-04-27 10:20:04 +08:00
parent 55bdb33f35
commit a40229d171
12 changed files with 625 additions and 21 deletions

View File

@ -5,10 +5,14 @@ else()
set(REQ_URL "https://github.com/openssl/openssl/archive/refs/tags/OpenSSL_1_1_1k.tar.gz")
set(MD5 "bdd51a68ad74618dd2519da8e0bcc759")
endif()
mindspore_add_pkg(openssl
VER 1.1.0
LIBS ssl crypto
URL ${REQ_URL}
MD5 ${MD5}
CONFIGURE_COMMAND ./config no-zlib no-shared)
include_directories(${openssl_INC})
if(${CMAKE_SYSTEM_NAME} MATCHES "Linux")
mindspore_add_pkg(openssl
VER 1.1.0
LIBS ssl crypto
URL ${REQ_URL}
MD5 ${MD5}
CONFIGURE_COMMAND ./config no-zlib no-shared)
include_directories(${openssl_INC})
add_library(mindspore::ssl ALIAS openssl::ssl)
add_library(mindspore::crypto ALIAS openssl::crypto)
endif()

View File

@ -226,6 +226,7 @@ set(SUB_COMP
pipeline/jit
pipeline/pynative
common debug pybind_api utils vm profiler ps
crypto
)
foreach(_comp ${SUB_COMP})

View File

@ -0,0 +1,6 @@
file(GLOB_RECURSE _CRYPTO_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
add_library(_mindspore_crypto_obj OBJECT ${_CRYPTO_SRC_FILES})
if(${CMAKE_SYSTEM_NAME} MATCHES "Linux")
target_link_libraries(_mindspore_crypto_obj mindspore::crypto)
endif()

View File

@ -0,0 +1,347 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "crypto/crypto.h"
namespace mindspore {
namespace crypto {
int64_t Min(int64_t a, int64_t b) { return a < b ? a : b; }
Byte *intToByte(const int32_t &n) {
Byte *byte = new Byte[4];
memset(byte, 0, sizeof(Byte) * 4);
byte[0] = (Byte)(0xFF & n);
byte[1] = (Byte)((0xFF00 & n) >> 8);
byte[2] = (Byte)((0xFF0000 & n) >> 16);
byte[3] = (Byte)((0xFF000000 & n) >> 24);
return byte;
}
int32_t ByteToint(const Byte *byteArray) {
int32_t res = byteArray[0] & 0xFF;
res |= ((byteArray[1] << 8) & 0xFF00);
res |= ((byteArray[2] << 16) & 0xFF0000);
res += ((byteArray[3] << 24) & 0xFF000000);
return res;
}
bool IsCipherFile(std::string file_path) {
char *int_buf = new char[4];
int flag = 0;
std::ifstream fid(file_path, std::ios::in | std::ios::binary);
if (!fid) {
MS_LOG(ERROR) << "Open file failed";
exit(-1);
}
fid.read(int_buf, sizeof(int32_t));
fid.close();
flag = ByteToint(reinterpret_cast<Byte *>(int_buf));
delete[] int_buf;
return flag == MAGIC_NUM;
}
#if defined(_WIN32)
Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, Byte *key, const int32_t key_len,
const std::string &enc_mode) {
MS_EXCEPTION(NotSupportError) << "Unsupported feature in Windows platform.";
}
Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *key, const int32_t key_len,
const std::string &dec_mode) {
MS_EXCEPTION(NotSupportError) << "Unsupported feature in Windows platform.";
}
#else
bool ParseEncryptData(const Byte *encrypt_data, const int32_t encrypt_len, Byte **iv, int32_t *iv_len,
Byte **cipher_data, int32_t *cipher_len) {
// Encrypt data is organized in order to iv_len, iv, cipher_len, cipher_data
Byte buf[4];
memcpy(buf, encrypt_data, 4);
*iv_len = ByteToint(buf);
memcpy(buf, encrypt_data + *iv_len + 4, 4);
*cipher_len = ByteToint(buf);
if (*iv_len <= 0 || *cipher_len <= 0 || *iv_len + *cipher_len + 8 != encrypt_len) {
MS_LOG(ERROR) << "Failed to parse encrypt data.";
return false;
}
*iv = new Byte[*iv_len];
memcpy(*iv, encrypt_data + 4, *iv_len);
*cipher_data = new Byte[*cipher_len];
memcpy(*cipher_data, encrypt_data + *iv_len + 8, *cipher_len);
return true;
}
bool ParseMode(std::string mode, std::string *alg_mode, std::string *work_mode) {
std::smatch results;
std::regex re("([A-Z]{3})-([A-Z]{3})");
if (!std::regex_match(mode.c_str(), re)) {
MS_LOG(ERROR) << "Mode " << mode << " is invalid.";
return false;
}
std::regex_search(mode, results, re);
*alg_mode = results[1];
*work_mode = results[2];
return true;
}
EVP_CIPHER_CTX *GetEVP_CIPHER_CTX(const std::string &work_mode, const Byte *key, const int32_t key_len, const Byte *iv,
int flag) {
int ret = 0;
EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
if (work_mode != "GCM" && work_mode != "CBC") {
MS_LOG(ERROR) << "Work mode " << work_mode << " is invalid.";
return nullptr;
}
const EVP_CIPHER *(*funcPtr)() = nullptr;
if (work_mode == "GCM") {
switch (key_len) {
case 16:
funcPtr = EVP_aes_128_gcm;
break;
case 24:
funcPtr = EVP_aes_192_gcm;
break;
case 32:
funcPtr = EVP_aes_256_gcm;
break;
default:
MS_EXCEPTION(ValueError) << "The key length must be 16, 24 or 32, but got key length is " << key_len << ".";
}
} else if (work_mode == "CBC") {
switch (key_len) {
case 16:
funcPtr = EVP_aes_128_cbc;
break;
case 24:
funcPtr = EVP_aes_192_cbc;
break;
case 32:
funcPtr = EVP_aes_256_cbc;
break;
default:
MS_EXCEPTION(ValueError) << "The key length must be 16, 24 or 32, but got key length is " << key_len << ".";
}
}
if (flag == 0) {
ret = EVP_EncryptInit_ex(ctx, funcPtr(), NULL, key, iv);
} else if (flag == 1) {
ret = EVP_DecryptInit_ex(ctx, funcPtr(), NULL, key, iv);
}
if (ret != 1) {
MS_LOG(ERROR) << "EVP_EncryptInit_ex failed";
return nullptr;
}
if (work_mode == "CBC") EVP_CIPHER_CTX_set_padding(ctx, 1);
return ctx;
}
bool _BlockEncrypt(Byte *encrypt_data, int64_t *encrypt_data_len, Byte *plain_data, const int64_t plain_len, Byte *key,
const int32_t key_len, const std::string &enc_mode) {
// Encrypted according to enc_key and enc_mode, the format of the returned encrypted data block is "total length +
// iv length + iv + plain text length + cipher text length + cipher text"
int32_t cipher_len = 0; // cipher length
int32_t iv_len = AES_BLOCK_SIZE;
Byte *iv = new Byte[iv_len];
RAND_bytes(iv, sizeof(Byte) * iv_len);
Byte *iv_cpy = new Byte[16];
memcpy(iv_cpy, iv, 16);
// set the encryption length
int32_t ret = 0;
int32_t flen = 0;
std::string alg_mode;
std::string work_mode;
if (!ParseMode(enc_mode, &alg_mode, &work_mode)) {
return false;
}
auto ctx = GetEVP_CIPHER_CTX(work_mode, key, key_len, iv, 0);
if (ctx == nullptr) {
MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX.";
return false;
}
Byte *cipher_data;
cipher_data = new Byte[plain_len + 16];
ret = EVP_EncryptUpdate(ctx, cipher_data, &cipher_len, plain_data, plain_len);
if (ret != 1) {
MS_LOG(ERROR) << "EVP_EncryptUpdate failed";
delete[] cipher_data;
return false;
}
if (work_mode == "CBC") {
EVP_EncryptFinal_ex(ctx, cipher_data + cipher_len, &flen);
cipher_len += flen;
}
EVP_CIPHER_CTX_free(ctx);
int64_t cur = 0;
*encrypt_data_len = sizeof(int32_t) * 2 + iv_len + cipher_len; // 按iv长度、iv、明文长度、密文长度、密文进行拼接
memcpy(encrypt_data + cur, intToByte(*encrypt_data_len), 4);
cur += 4;
memcpy(encrypt_data + cur, intToByte(iv_len), 4);
cur += 4;
memcpy(encrypt_data + cur, iv_cpy, iv_len);
cur += iv_len;
memcpy(encrypt_data + cur, intToByte(cipher_len), 4);
cur += 4;
memcpy(encrypt_data + cur, cipher_data, cipher_len);
*encrypt_data_len += 4;
delete[] cipher_data;
return true;
}
bool _BlockDecrypt(Byte **plain_data, int32_t *plain_len, Byte *encrypt_data, const int64_t encrypt_len, Byte *key,
const int32_t key_len, const std::string &dec_mode) {
// Decrypt according to dec_key and dec_mode, the format of the encrypted data block is "iv length + iv +
// plain text data length + cipher text data length + cipher text data"
std::string alg_mode;
std::string work_mode;
if (!ParseMode(dec_mode, &alg_mode, &work_mode)) {
return false;
}
// 解析加密数据
int32_t iv_len = 0;
int32_t cipher_len = 0;
Byte *iv = NULL;
Byte *cipher_data = NULL;
if (!ParseEncryptData(encrypt_data, encrypt_len, &iv, &iv_len, &cipher_data, &cipher_len)) {
return false;
}
*plain_data = new Byte[cipher_len + 16];
if (*plain_data == NULL) {
MS_LOG(ERROR) << "Unable to allocate memory for decrypt_string.";
return false;
}
// 解密密文
int ret = 0;
int mlen = 0;
auto ctx = GetEVP_CIPHER_CTX(work_mode, key, key_len, iv, 1);
if (ctx == nullptr) {
MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX.";
return false;
}
ret = EVP_DecryptUpdate(ctx, *plain_data, plain_len, cipher_data, cipher_len);
if (ret != 1) {
MS_LOG(ERROR) << "EVP_DecryptUpdate failed";
return false;
}
if (work_mode == "CBC") {
ret = EVP_DecryptFinal_ex(ctx, *plain_data + *plain_len, &mlen);
if (ret != 1) {
MS_LOG(ERROR) << "EVP_DecryptFinal_ex failed";
return false;
}
*plain_len += mlen;
}
delete[] iv;
delete[] cipher_data;
EVP_CIPHER_CTX_free(ctx);
return true;
}
Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, Byte *key, const int32_t key_len,
const std::string &enc_mode) {
int64_t cur_pos = 0;
int64_t block_enc_len = 0;
int64_t encrypt_buf_len = plain_len + (plain_len / MAX_BLOCK_SIZE + 1) * 100;
Byte *encrypt_data = new Byte[encrypt_buf_len];
Byte *block_buf = new Byte[MAX_BLOCK_SIZE];
Byte *block_enc_buf = new Byte[MAX_BLOCK_SIZE + 100];
*encrypt_len = 0;
while (cur_pos < plain_len) {
int64_t cur_block_size = Min(MAX_BLOCK_SIZE, plain_len - cur_pos);
memcpy(block_buf, plain_data + cur_pos, cur_block_size);
if (!_BlockEncrypt(block_enc_buf, &block_enc_len, block_buf, cur_block_size, key, key_len, enc_mode)) {
delete[] block_buf;
delete[] block_enc_buf;
delete[] encrypt_data;
MS_EXCEPTION(ValueError) << "Failed to encrypt data, please check if enc_key or enc_mode is valid.";
}
memcpy(encrypt_data + *encrypt_len, intToByte(MAGIC_NUM), sizeof(int32_t));
*encrypt_len += sizeof(int32_t);
memcpy(encrypt_data + *encrypt_len, block_enc_buf, block_enc_len);
*encrypt_len += block_enc_len;
cur_pos += cur_block_size;
}
delete[] block_buf;
delete[] block_enc_buf;
return encrypt_data;
}
Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *key, const int32_t key_len,
const std::string &dec_mode) {
Byte *decrypt_data = nullptr;
char *block_buf = new char[MAX_BLOCK_SIZE * 2];
char *int_buf = new char[4];
// Byte *decrypt_block_buf = new Byte[100];
Byte *decrypt_block_buf = nullptr;
int32_t decrypt_block_len;
std::ifstream fid(encrypt_data_path, std::ios::in | std::ios::binary);
if (!fid) {
MS_LOG(ERROR) << "Open file failed";
exit(-1);
}
fid.seekg(0, std::ios_base::end);
int64_t file_size = fid.tellg();
fid.clear();
fid.seekg(0);
decrypt_data = new Byte[file_size];
*decrypt_len = 0;
while (fid.tellg() < file_size) {
fid.read(int_buf, sizeof(int32_t));
int cipher_flag = ByteToint(reinterpret_cast<Byte *>(int_buf));
if (cipher_flag != MAGIC_NUM) {
MS_EXCEPTION(ValueError) << "File \"" << encrypt_data_path
<< "\"is not an encrypted file and cannot be decrypted";
}
fid.read(int_buf, sizeof(int32_t));
int64_t block_size = ByteToint(reinterpret_cast<Byte *>(int_buf));
fid.read(block_buf, sizeof(char) * block_size);
if (!(_BlockDecrypt(&decrypt_block_buf, &decrypt_block_len, reinterpret_cast<Byte *>(block_buf), block_size, key,
key_len, dec_mode))) {
delete[] block_buf;
delete[] int_buf;
delete[] decrypt_data;
MS_EXCEPTION(ValueError) << "Failed to decrypt data, please check if dec_key or dec_mode is valid";
}
memcpy(decrypt_data, decrypt_block_buf, decrypt_block_len);
*decrypt_len += decrypt_block_len;
}
fid.close();
delete[] block_buf;
delete[] int_buf;
return decrypt_data;
}
#endif
} // namespace crypto
} // namespace mindspore

View File

@ -0,0 +1,46 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_CRYPTO_CRYPTO_H
#define MINDSPORE_CCSRC_CRYPTO_CRYPTO_H
#if not defined(_WIN32)
#include <openssl/aes.h>
#include <openssl/evp.h>
#include <openssl/rand.h>
#endif
#include <stdio.h>
#include <fstream>
#include <string>
#include <regex>
#include "utils/log_adapter.h"
typedef unsigned char Byte;
namespace mindspore {
namespace crypto {
const int MAX_BLOCK_SIZE = 512 * 1024 * 1024; // Maximum ciphertext segment 512MB
const unsigned int MAGIC_NUM = 0x7F3A5ED8; // Magic number
Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, Byte *key, const int32_t key_len,
const std::string &enc_mode);
Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *key, const int32_t key_len,
const std::string &dec_mode);
bool IsCipherFile(const std::string file_path);
} // namespace crypto
} // namespace mindspore
#endif

View File

@ -0,0 +1,37 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "crypto/crypto_pybind.h"
namespace mindspore {
namespace crypto {
py::bytes PyEncrypt(char *plain_data, const int64_t plain_len, char *key, const int32_t key_len, std::string enc_mode) {
int64_t encrypt_len;
char *encrypt_data;
encrypt_data = reinterpret_cast<char *>(Encrypt(&encrypt_len, reinterpret_cast<Byte *>(plain_data), plain_len,
reinterpret_cast<Byte *>(key), key_len, enc_mode));
return py::bytes(encrypt_data, encrypt_len);
}
py::bytes PyDecrypt(std::string encrypt_data_path, char *key, const int32_t key_len, std::string dec_mode) {
int64_t decrypt_len;
char *decrypt_data;
decrypt_data = reinterpret_cast<char *>(
Decrypt(&decrypt_len, encrypt_data_path, reinterpret_cast<Byte *>(key), key_len, dec_mode));
return py::bytes(decrypt_data, decrypt_len);
}
bool PyIsCipherFile(std::string file_path) { return IsCipherFile(file_path); }
} // namespace crypto
} // namespace mindspore

View File

@ -0,0 +1,32 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_CRYPTO_CRYPTO_PYBIND_H
#define MINDSPORE_CCSRC_CRYPTO_CRYPTO_PYBIND_H
#include "crypto/crypto.h"
#include <pybind11/pybind11.h>
#include <string>
namespace py = pybind11;
namespace mindspore {
namespace crypto {
py::bytes PyEncrypt(char *plain_data, const int64_t plain_len, char *key, const int32_t key_len, std::string enc_mode);
py::bytes PyDecrypt(std::string encrypt_data_path, char *key, const int32_t key_len, std::string dec_mode);
bool PyIsCipherFile(std::string file_path);
} // namespace crypto
} // namespace mindspore
#endif

View File

@ -28,6 +28,7 @@
#include "utils/mpi/mpi_config.h"
#include "frontend/parallel/context.h"
#include "frontend/parallel/costmodel_context.h"
#include "crypto/crypto_pybind.h"
#ifdef ENABLE_GPU_COLLECTIVE
#include "runtime/device/gpu/distribution/collective_init.h"
#else
@ -330,4 +331,8 @@ PYBIND11_MODULE(_c_expression, m) {
(void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy")
.def(py::init())
.def("get_all_ops_info", &OpInfoLoaderPy::GetAllOpsInfo, "get all ops info.");
(void)m.def("_encrypt", &mindspore::crypto::PyEncrypt, "Encrypt the data.");
(void)m.def("_decrypt", &mindspore::crypto::PyDecrypt, "Decrypt the data.");
(void)m.def("_is_cipher_file", &mindspore::crypto::PyIsCipherFile, "Determine whether the file is encrypted");
}

View File

@ -82,6 +82,10 @@ class CheckpointConfig:
async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False.
saved_network (Cell): Network to be saved in checkpoint file. If the saved_network has no relation
with the network in training, the initial value of saved_network will be saved. Default: None.
enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is None, the encryption
is not required. Default: None.
enc_mode (str): This parameter is valid only when enc_key is not set to None. Specifies the encryption
mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'.
Raises:
ValueError: If the input_param is None or 0.
@ -126,7 +130,9 @@ class CheckpointConfig:
keep_checkpoint_per_n_minutes=0,
integrated_save=True,
async_save=False,
saved_network=None):
saved_network=None,
enc_key=None,
enc_mode='AES-GCM'):
if save_checkpoint_steps is not None:
save_checkpoint_steps = Validator.check_non_negative_int(save_checkpoint_steps)
@ -160,6 +166,8 @@ class CheckpointConfig:
self._integrated_save = Validator.check_bool(integrated_save)
self._async_save = Validator.check_bool(async_save)
self._saved_network = saved_network
self._enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
self._enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)
@property
def save_checkpoint_steps(self):
@ -196,6 +204,16 @@ class CheckpointConfig:
"""Get the value of _saved_network"""
return self._saved_network
@property
def enc_key(self):
"""Get the value of _enc_key"""
return self._enc_key
@property
def enc_mode(self):
"""Get the value of _enc_mode"""
return self._enc_mode
def get_checkpoint_policy(self):
"""Get the policy of checkpoint."""
checkpoint_policy = {'save_checkpoint_steps': self.save_checkpoint_steps,
@ -355,7 +373,7 @@ class ModelCheckpoint(Callback):
network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network
save_checkpoint(network, cur_file, self._config.integrated_save,
self._config.async_save)
self._config.async_save, self._config.enc_key, self._config.enc_mode)
self._latest_ckpt_file_name = cur_file

View File

@ -14,6 +14,7 @@
# ============================================================================
"""Model and parameters serialization."""
import os
import sys
import stat
import math
@ -40,7 +41,7 @@ from mindspore._checkparam import check_input_data, Validator
from mindspore.compression.export import quant_export
from mindspore.parallel._tensor import _load_tensor
from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices
from .._c_expression import load_mindir
from .._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file
tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
@ -120,14 +121,19 @@ def _update_param(param, new_param):
param.set_data(type(param.data)(new_param.data))
def _exec_save(ckpt_file_name, data_list):
def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
"""Execute the process of saving checkpoint into file."""
try:
MAX_BLOCK_SIZE = 1024*1024*512
with _ckpt_mutex:
if os.path.exists(ckpt_file_name):
os.remove(ckpt_file_name)
with open(ckpt_file_name, "ab") as f:
if enc_key is not None:
plain_data = bytes(0)
cipher_data = bytes(0)
for name, value in data_list.items():
data_size = value[2].nbytes / 1024
if data_size > SLICE_SIZE:
@ -145,7 +151,19 @@ def _exec_save(ckpt_file_name, data_list):
param_tensor.tensor_type = value[1]
param_tensor.tensor_content = param_slice.tobytes()
f.write(checkpoint_list.SerializeToString())
if enc_key is None:
f.write(checkpoint_list.SerializeToString())
else:
plain_data += checkpoint_list.SerializeToString()
while len(plain_data) >= MAX_BLOCK_SIZE:
cipher_data += _encrypt(plain_data[0: MAX_BLOCK_SIZE], MAX_BLOCK_SIZE, enc_key,
len(enc_key), enc_mode)
plain_data = plain_data[MAX_BLOCK_SIZE:]
if enc_key is not None:
if plain_data:
cipher_data += _encrypt(plain_data, len(plain_data), enc_key, len(enc_key), enc_mode)
f.write(cipher_data)
os.chmod(ckpt_file_name, stat.S_IRUSR)
@ -154,7 +172,7 @@ def _exec_save(ckpt_file_name, data_list):
raise e
def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False):
def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False, enc_key=None, enc_mode="AES-GCM"):
"""
Saves checkpoint info to a specified file.
@ -166,6 +184,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F
ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten.
integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: True
async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False
enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is None, the encryption
is not required. Default: None.
enc_mode (str): This parameter is valid only when enc_key is not set to None. Specifies the encryption
mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'.
Raises:
TypeError: If the parameter save_obj is not `nn.Cell` or list type. And if the parameter
@ -176,6 +198,8 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F
raise TypeError("The parameter save_obj should be nn.Cell or list, but got {}".format(type(save_obj)))
integrated_save = Validator.check_bool(integrated_save)
async_save = Validator.check_bool(async_save)
enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)
logger.info("Execute the process of saving checkpoint files.")
@ -218,10 +242,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F
data_list[key].append(data)
if async_save:
thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list), name="asyn_save_ckpt")
thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list, enc_key, enc_mode), name="asyn_save_ckpt")
thr.start()
else:
_exec_save(ckpt_file_name, data_list)
_exec_save(ckpt_file_name, data_list, enc_key, enc_mode)
logger.info("Saving checkpoint process is finished.")
@ -278,7 +302,7 @@ def load(file_name):
return graph
def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None):
def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None, dec_mode="AES-GCM"):
"""
Loads checkpoint info from a specified file.
@ -289,6 +313,10 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
in the param_dict into net with the same suffix. Default: False
filter_prefix (Union[str, list[str], tuple[str]]): Parameters starting with the filter_prefix
will not be loaded. Default: None.
dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is None, the decryption
is not required. Default: None.
dec_mode (str): This parameter is valid only when dec_key is not set to None. Specifies the decryption
mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'.
Returns:
Dict, key is parameter name, value is a Parameter.
@ -303,15 +331,25 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
>>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1")
"""
ckpt_file_name, filter_prefix = _check_checkpoint_param(ckpt_file_name, filter_prefix)
dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
logger.info("Execute the process of loading checkpoint files.")
checkpoint_list = Checkpoint()
try:
with open(ckpt_file_name, "rb") as f:
pb_content = f.read()
if dec_key is None:
with open(ckpt_file_name, "rb") as f:
pb_content = f.read()
else:
pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode)
checkpoint_list.ParseFromString(pb_content)
except BaseException as e:
logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", ckpt_file_name)
if _is_cipher_file(ckpt_file_name):
logger.error("Failed to read the checkpoint file `%s`. The file may be encrypted, please pass in the "
"dec_key.", ckpt_file_name)
else:
logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", \
ckpt_file_name)
raise ValueError(e.__str__())
parameter_dict = {}
@ -1075,7 +1113,7 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
return merged_parameter
def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None):
def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None, dec_key=None, dec_mode='AES-GCM'):
"""
Load checkpoint into net for distributed predication.
@ -1088,6 +1126,10 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
elements are [dev_matrix, tensor_map, param_split_shape, field]. If None,
it means that the predication process just uses single device.
Default: None.
dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is None, the decryption
is not required. Default: None.
dec_mode (str): This parameter is valid only when dec_key is not set to None. Specifies the decryption
mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'.
Raises:
TypeError: The type of inputs do not match the requirements.
@ -1106,6 +1148,9 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
f"dev_matrix (list[int]), tensor_map (list[int]), "
f"param_split_shape (list[int]) and field_size (zero).")
dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file")
_train_strategy = build_searched_strategy(train_strategy_filename)
train_strategy = _convert_to_list(_train_strategy)
@ -1128,7 +1173,7 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
param_rank = rank_list[param.name][0]
skip_merge_split = rank_list[param.name][1]
for rank in param_rank:
sliced_param = load_checkpoint(checkpoint_filenames[rank])[param.name]
sliced_param = load_checkpoint(checkpoint_filenames[rank], dec_key=dec_key, dec_mode=dec_mode)[param.name]
sliced_params.append(sliced_param)
if skip_merge_split:
split_param = sliced_params[0]

View File

@ -14,6 +14,7 @@
# ============================================================================
"""test callback function."""
import os
import platform
import stat
from unittest import mock
@ -246,6 +247,43 @@ def test_checkpoint_save_ckpt_seconds():
ckpt_cb2.step_end(run_context)
def test_checkpoint_save_ckpt_with_encryption():
"""Test checkpoint save ckpt with encryption."""
train_config = CheckpointConfig(
save_checkpoint_steps=16,
save_checkpoint_seconds=0,
keep_checkpoint_max=5,
keep_checkpoint_per_n_minutes=0,
enc_key=os.urandom(16),
enc_mode="AES-GCM")
ckpt_cb = ModelCheckpoint(config=train_config)
cb_params = _InternalCallbackParam()
net = Net()
loss = nn.SoftmaxCrossEntropyWithLogits()
optim = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
network_ = WithLossCell(net, loss)
_train_network = TrainOneStepCell(network_, optim)
cb_params.train_network = _train_network
cb_params.epoch_num = 10
cb_params.cur_epoch_num = 5
cb_params.cur_step_num = 160
cb_params.batch_num = 32
run_context = RunContext(cb_params)
ckpt_cb.begin(run_context)
ckpt_cb.step_end(run_context)
ckpt_cb2 = ModelCheckpoint(config=train_config)
cb_params.cur_epoch_num = 1
cb_params.cur_step_num = 15
if platform.system().lower() == "windows":
with pytest.raises(NotImplementedError):
ckpt_cb2.begin(run_context)
ckpt_cb2.step_end(run_context)
else:
ckpt_cb2.begin(run_context)
ckpt_cb2.step_end(run_context)
def test_CallbackManager():
"""TestCallbackManager."""
ck_obj = ModelCheckpoint()

View File

@ -14,6 +14,7 @@
# ============================================================================
"""ut for model serialize(save/load)"""
import os
import platform
import stat
import time
@ -299,6 +300,30 @@ def test_load_checkpoint_empty_file():
load_checkpoint("empty.ckpt")
def test_save_and_load_checkpoint_for_network_with_encryption():
""" test save and checkpoint for network with encryption"""
net = Net()
loss = SoftmaxCrossEntropyWithLogits(sparse=True)
opt = Momentum(net.trainable_params(), 0.0, 0.9, 0.0001, 1024)
loss_net = WithLossCell(net, loss)
train_network = TrainOneStepCell(loss_net, opt)
key = os.urandom(16)
mode = "AES-GCM"
ckpt_path = "./encrypt_ckpt.ckpt"
if platform.system().lower() == "windows":
with pytest.raises(NotImplementedError):
save_checkpoint(train_network, ckpt_file_name=ckpt_path, enc_key=key, enc_mode=mode)
param_dict = load_checkpoint(ckpt_path, dec_key=key, dec_mode="AES-GCM")
load_param_into_net(net, param_dict)
else:
save_checkpoint(train_network, ckpt_file_name=ckpt_path, enc_key=key, enc_mode=mode)
param_dict = load_checkpoint(ckpt_path, dec_key=key, dec_mode="AES-GCM")
load_param_into_net(net, param_dict)
if os.path.exists(ckpt_path):
os.remove(ckpt_path)
class MYNET(nn.Cell):
""" NET definition """