diff --git a/cmake/external_libs/openssl.cmake b/cmake/external_libs/openssl.cmake index 0b6cd35a2ce..2663c77daf3 100644 --- a/cmake/external_libs/openssl.cmake +++ b/cmake/external_libs/openssl.cmake @@ -5,10 +5,14 @@ else() set(REQ_URL "https://github.com/openssl/openssl/archive/refs/tags/OpenSSL_1_1_1k.tar.gz") set(MD5 "bdd51a68ad74618dd2519da8e0bcc759") endif() -mindspore_add_pkg(openssl - VER 1.1.0 - LIBS ssl crypto - URL ${REQ_URL} - MD5 ${MD5} - CONFIGURE_COMMAND ./config no-zlib no-shared) -include_directories(${openssl_INC}) \ No newline at end of file +if(${CMAKE_SYSTEM_NAME} MATCHES "Linux") + mindspore_add_pkg(openssl + VER 1.1.0 + LIBS ssl crypto + URL ${REQ_URL} + MD5 ${MD5} + CONFIGURE_COMMAND ./config no-zlib no-shared) + include_directories(${openssl_INC}) + add_library(mindspore::ssl ALIAS openssl::ssl) + add_library(mindspore::crypto ALIAS openssl::crypto) +endif() diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index d0da121c3bc..406cfb9047b 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -226,6 +226,7 @@ set(SUB_COMP pipeline/jit pipeline/pynative common debug pybind_api utils vm profiler ps + crypto ) foreach(_comp ${SUB_COMP}) diff --git a/mindspore/ccsrc/crypto/CMakeLists.txt b/mindspore/ccsrc/crypto/CMakeLists.txt new file mode 100644 index 00000000000..e96a7a11e61 --- /dev/null +++ b/mindspore/ccsrc/crypto/CMakeLists.txt @@ -0,0 +1,6 @@ +file(GLOB_RECURSE _CRYPTO_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +add_library(_mindspore_crypto_obj OBJECT ${_CRYPTO_SRC_FILES}) + +if(${CMAKE_SYSTEM_NAME} MATCHES "Linux") + target_link_libraries(_mindspore_crypto_obj mindspore::crypto) +endif() diff --git a/mindspore/ccsrc/crypto/crypto.cc b/mindspore/ccsrc/crypto/crypto.cc new file mode 100644 index 00000000000..748f73a6535 --- /dev/null +++ b/mindspore/ccsrc/crypto/crypto.cc @@ -0,0 +1,347 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "crypto/crypto.h" + +namespace mindspore { +namespace crypto { +int64_t Min(int64_t a, int64_t b) { return a < b ? a : b; } + +Byte *intToByte(const int32_t &n) { + Byte *byte = new Byte[4]; + memset(byte, 0, sizeof(Byte) * 4); + byte[0] = (Byte)(0xFF & n); + byte[1] = (Byte)((0xFF00 & n) >> 8); + byte[2] = (Byte)((0xFF0000 & n) >> 16); + byte[3] = (Byte)((0xFF000000 & n) >> 24); + return byte; +} + +int32_t ByteToint(const Byte *byteArray) { + int32_t res = byteArray[0] & 0xFF; + res |= ((byteArray[1] << 8) & 0xFF00); + res |= ((byteArray[2] << 16) & 0xFF0000); + res += ((byteArray[3] << 24) & 0xFF000000); + return res; +} + +bool IsCipherFile(std::string file_path) { + char *int_buf = new char[4]; + int flag = 0; + std::ifstream fid(file_path, std::ios::in | std::ios::binary); + if (!fid) { + MS_LOG(ERROR) << "Open file failed"; + exit(-1); + } + fid.read(int_buf, sizeof(int32_t)); + fid.close(); + flag = ByteToint(reinterpret_cast(int_buf)); + delete[] int_buf; + return flag == MAGIC_NUM; +} +#if defined(_WIN32) +Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, Byte *key, const int32_t key_len, + const std::string &enc_mode) { + MS_EXCEPTION(NotSupportError) << "Unsupported feature in Windows platform."; +} + +Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *key, const int32_t key_len, + const std::string &dec_mode) { + MS_EXCEPTION(NotSupportError) << "Unsupported feature in Windows platform."; +} +#else + +bool ParseEncryptData(const Byte *encrypt_data, const int32_t encrypt_len, Byte **iv, int32_t *iv_len, + Byte **cipher_data, int32_t *cipher_len) { + // Encrypt data is organized in order to iv_len, iv, cipher_len, cipher_data + Byte buf[4]; + memcpy(buf, encrypt_data, 4); + *iv_len = ByteToint(buf); + memcpy(buf, encrypt_data + *iv_len + 4, 4); + *cipher_len = ByteToint(buf); + if (*iv_len <= 0 || *cipher_len <= 0 || *iv_len + *cipher_len + 8 != encrypt_len) { + MS_LOG(ERROR) << "Failed to parse encrypt data."; + return false; + } + *iv = new Byte[*iv_len]; + memcpy(*iv, encrypt_data + 4, *iv_len); + *cipher_data = new Byte[*cipher_len]; + memcpy(*cipher_data, encrypt_data + *iv_len + 8, *cipher_len); + return true; +} + +bool ParseMode(std::string mode, std::string *alg_mode, std::string *work_mode) { + std::smatch results; + std::regex re("([A-Z]{3})-([A-Z]{3})"); + if (!std::regex_match(mode.c_str(), re)) { + MS_LOG(ERROR) << "Mode " << mode << " is invalid."; + return false; + } + std::regex_search(mode, results, re); + *alg_mode = results[1]; + *work_mode = results[2]; + return true; +} + +EVP_CIPHER_CTX *GetEVP_CIPHER_CTX(const std::string &work_mode, const Byte *key, const int32_t key_len, const Byte *iv, + int flag) { + int ret = 0; + EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new(); + if (work_mode != "GCM" && work_mode != "CBC") { + MS_LOG(ERROR) << "Work mode " << work_mode << " is invalid."; + return nullptr; + } + + const EVP_CIPHER *(*funcPtr)() = nullptr; + if (work_mode == "GCM") { + switch (key_len) { + case 16: + funcPtr = EVP_aes_128_gcm; + break; + case 24: + funcPtr = EVP_aes_192_gcm; + break; + case 32: + funcPtr = EVP_aes_256_gcm; + break; + default: + MS_EXCEPTION(ValueError) << "The key length must be 16, 24 or 32, but got key length is " << key_len << "."; + } + } else if (work_mode == "CBC") { + switch (key_len) { + case 16: + funcPtr = EVP_aes_128_cbc; + break; + case 24: + funcPtr = EVP_aes_192_cbc; + break; + case 32: + funcPtr = EVP_aes_256_cbc; + break; + default: + MS_EXCEPTION(ValueError) << "The key length must be 16, 24 or 32, but got key length is " << key_len << "."; + } + } + + if (flag == 0) { + ret = EVP_EncryptInit_ex(ctx, funcPtr(), NULL, key, iv); + } else if (flag == 1) { + ret = EVP_DecryptInit_ex(ctx, funcPtr(), NULL, key, iv); + } + + if (ret != 1) { + MS_LOG(ERROR) << "EVP_EncryptInit_ex failed"; + return nullptr; + } + if (work_mode == "CBC") EVP_CIPHER_CTX_set_padding(ctx, 1); + return ctx; +} + +bool _BlockEncrypt(Byte *encrypt_data, int64_t *encrypt_data_len, Byte *plain_data, const int64_t plain_len, Byte *key, + const int32_t key_len, const std::string &enc_mode) { + // Encrypted according to enc_key and enc_mode, the format of the returned encrypted data block is "total length + + // iv length + iv + plain text length + cipher text length + cipher text" + int32_t cipher_len = 0; // cipher length + + int32_t iv_len = AES_BLOCK_SIZE; + Byte *iv = new Byte[iv_len]; + RAND_bytes(iv, sizeof(Byte) * iv_len); + + Byte *iv_cpy = new Byte[16]; + memcpy(iv_cpy, iv, 16); + + // set the encryption length + int32_t ret = 0; + int32_t flen = 0; + std::string alg_mode; + std::string work_mode; + if (!ParseMode(enc_mode, &alg_mode, &work_mode)) { + return false; + } + + auto ctx = GetEVP_CIPHER_CTX(work_mode, key, key_len, iv, 0); + if (ctx == nullptr) { + MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX."; + return false; + } + + Byte *cipher_data; + cipher_data = new Byte[plain_len + 16]; + ret = EVP_EncryptUpdate(ctx, cipher_data, &cipher_len, plain_data, plain_len); + if (ret != 1) { + MS_LOG(ERROR) << "EVP_EncryptUpdate failed"; + delete[] cipher_data; + return false; + } + if (work_mode == "CBC") { + EVP_EncryptFinal_ex(ctx, cipher_data + cipher_len, &flen); + cipher_len += flen; + } + EVP_CIPHER_CTX_free(ctx); + + int64_t cur = 0; + *encrypt_data_len = sizeof(int32_t) * 2 + iv_len + cipher_len; // 按iv长度、iv、明文长度、密文长度、密文进行拼接 + + memcpy(encrypt_data + cur, intToByte(*encrypt_data_len), 4); + cur += 4; + memcpy(encrypt_data + cur, intToByte(iv_len), 4); + cur += 4; + memcpy(encrypt_data + cur, iv_cpy, iv_len); + cur += iv_len; + memcpy(encrypt_data + cur, intToByte(cipher_len), 4); + cur += 4; + memcpy(encrypt_data + cur, cipher_data, cipher_len); + *encrypt_data_len += 4; + + delete[] cipher_data; + return true; +} + +bool _BlockDecrypt(Byte **plain_data, int32_t *plain_len, Byte *encrypt_data, const int64_t encrypt_len, Byte *key, + const int32_t key_len, const std::string &dec_mode) { + // Decrypt according to dec_key and dec_mode, the format of the encrypted data block is "iv length + iv + + // plain text data length + cipher text data length + cipher text data" + std::string alg_mode; + std::string work_mode; + + if (!ParseMode(dec_mode, &alg_mode, &work_mode)) { + return false; + } + + // 解析加密数据 + int32_t iv_len = 0; + int32_t cipher_len = 0; + Byte *iv = NULL; + Byte *cipher_data = NULL; + + if (!ParseEncryptData(encrypt_data, encrypt_len, &iv, &iv_len, &cipher_data, &cipher_len)) { + return false; + } + *plain_data = new Byte[cipher_len + 16]; + if (*plain_data == NULL) { + MS_LOG(ERROR) << "Unable to allocate memory for decrypt_string."; + return false; + } + + // 解密密文 + int ret = 0; + int mlen = 0; + + auto ctx = GetEVP_CIPHER_CTX(work_mode, key, key_len, iv, 1); + if (ctx == nullptr) { + MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX."; + return false; + } + ret = EVP_DecryptUpdate(ctx, *plain_data, plain_len, cipher_data, cipher_len); + if (ret != 1) { + MS_LOG(ERROR) << "EVP_DecryptUpdate failed"; + return false; + } + if (work_mode == "CBC") { + ret = EVP_DecryptFinal_ex(ctx, *plain_data + *plain_len, &mlen); + if (ret != 1) { + MS_LOG(ERROR) << "EVP_DecryptFinal_ex failed"; + return false; + } + *plain_len += mlen; + } + delete[] iv; + delete[] cipher_data; + EVP_CIPHER_CTX_free(ctx); + return true; +} + +Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, Byte *key, const int32_t key_len, + const std::string &enc_mode) { + int64_t cur_pos = 0; + int64_t block_enc_len = 0; + int64_t encrypt_buf_len = plain_len + (plain_len / MAX_BLOCK_SIZE + 1) * 100; + Byte *encrypt_data = new Byte[encrypt_buf_len]; + Byte *block_buf = new Byte[MAX_BLOCK_SIZE]; + Byte *block_enc_buf = new Byte[MAX_BLOCK_SIZE + 100]; + + *encrypt_len = 0; + while (cur_pos < plain_len) { + int64_t cur_block_size = Min(MAX_BLOCK_SIZE, plain_len - cur_pos); + memcpy(block_buf, plain_data + cur_pos, cur_block_size); + + if (!_BlockEncrypt(block_enc_buf, &block_enc_len, block_buf, cur_block_size, key, key_len, enc_mode)) { + delete[] block_buf; + delete[] block_enc_buf; + delete[] encrypt_data; + MS_EXCEPTION(ValueError) << "Failed to encrypt data, please check if enc_key or enc_mode is valid."; + } + memcpy(encrypt_data + *encrypt_len, intToByte(MAGIC_NUM), sizeof(int32_t)); + *encrypt_len += sizeof(int32_t); + memcpy(encrypt_data + *encrypt_len, block_enc_buf, block_enc_len); + *encrypt_len += block_enc_len; + cur_pos += cur_block_size; + } + delete[] block_buf; + delete[] block_enc_buf; + return encrypt_data; +} + +Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *key, const int32_t key_len, + const std::string &dec_mode) { + Byte *decrypt_data = nullptr; + char *block_buf = new char[MAX_BLOCK_SIZE * 2]; + char *int_buf = new char[4]; + // Byte *decrypt_block_buf = new Byte[100]; + Byte *decrypt_block_buf = nullptr; + int32_t decrypt_block_len; + + std::ifstream fid(encrypt_data_path, std::ios::in | std::ios::binary); + if (!fid) { + MS_LOG(ERROR) << "Open file failed"; + exit(-1); + } + fid.seekg(0, std::ios_base::end); + int64_t file_size = fid.tellg(); + fid.clear(); + fid.seekg(0); + decrypt_data = new Byte[file_size]; + + *decrypt_len = 0; + while (fid.tellg() < file_size) { + fid.read(int_buf, sizeof(int32_t)); + int cipher_flag = ByteToint(reinterpret_cast(int_buf)); + if (cipher_flag != MAGIC_NUM) { + MS_EXCEPTION(ValueError) << "File \"" << encrypt_data_path + << "\"is not an encrypted file and cannot be decrypted"; + } + fid.read(int_buf, sizeof(int32_t)); + + int64_t block_size = ByteToint(reinterpret_cast(int_buf)); + fid.read(block_buf, sizeof(char) * block_size); + if (!(_BlockDecrypt(&decrypt_block_buf, &decrypt_block_len, reinterpret_cast(block_buf), block_size, key, + key_len, dec_mode))) { + delete[] block_buf; + delete[] int_buf; + delete[] decrypt_data; + MS_EXCEPTION(ValueError) << "Failed to decrypt data, please check if dec_key or dec_mode is valid"; + } + memcpy(decrypt_data, decrypt_block_buf, decrypt_block_len); + *decrypt_len += decrypt_block_len; + } + fid.close(); + delete[] block_buf; + delete[] int_buf; + return decrypt_data; +} +#endif +} // namespace crypto +} // namespace mindspore diff --git a/mindspore/ccsrc/crypto/crypto.h b/mindspore/ccsrc/crypto/crypto.h new file mode 100644 index 00000000000..edf7e8176b7 --- /dev/null +++ b/mindspore/ccsrc/crypto/crypto.h @@ -0,0 +1,46 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_CRYPTO_CRYPTO_H +#define MINDSPORE_CCSRC_CRYPTO_CRYPTO_H + +#if not defined(_WIN32) +#include +#include +#include +#endif + +#include +#include +#include +#include +#include "utils/log_adapter.h" + +typedef unsigned char Byte; + +namespace mindspore { +namespace crypto { +const int MAX_BLOCK_SIZE = 512 * 1024 * 1024; // Maximum ciphertext segment 512MB +const unsigned int MAGIC_NUM = 0x7F3A5ED8; // Magic number + +Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, Byte *key, const int32_t key_len, + const std::string &enc_mode); +Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *key, const int32_t key_len, + const std::string &dec_mode); +bool IsCipherFile(const std::string file_path); +} // namespace crypto +} // namespace mindspore +#endif diff --git a/mindspore/ccsrc/crypto/crypto_pybind.cc b/mindspore/ccsrc/crypto/crypto_pybind.cc new file mode 100644 index 00000000000..59c83f2c6ff --- /dev/null +++ b/mindspore/ccsrc/crypto/crypto_pybind.cc @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "crypto/crypto_pybind.h" +namespace mindspore { +namespace crypto { +py::bytes PyEncrypt(char *plain_data, const int64_t plain_len, char *key, const int32_t key_len, std::string enc_mode) { + int64_t encrypt_len; + char *encrypt_data; + encrypt_data = reinterpret_cast(Encrypt(&encrypt_len, reinterpret_cast(plain_data), plain_len, + reinterpret_cast(key), key_len, enc_mode)); + return py::bytes(encrypt_data, encrypt_len); +} + +py::bytes PyDecrypt(std::string encrypt_data_path, char *key, const int32_t key_len, std::string dec_mode) { + int64_t decrypt_len; + char *decrypt_data; + decrypt_data = reinterpret_cast( + Decrypt(&decrypt_len, encrypt_data_path, reinterpret_cast(key), key_len, dec_mode)); + return py::bytes(decrypt_data, decrypt_len); +} +bool PyIsCipherFile(std::string file_path) { return IsCipherFile(file_path); } +} // namespace crypto +} // namespace mindspore diff --git a/mindspore/ccsrc/crypto/crypto_pybind.h b/mindspore/ccsrc/crypto/crypto_pybind.h new file mode 100644 index 00000000000..68135ba3763 --- /dev/null +++ b/mindspore/ccsrc/crypto/crypto_pybind.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_CRYPTO_CRYPTO_PYBIND_H +#define MINDSPORE_CCSRC_CRYPTO_CRYPTO_PYBIND_H +#include "crypto/crypto.h" +#include +#include + +namespace py = pybind11; + +namespace mindspore { +namespace crypto { +py::bytes PyEncrypt(char *plain_data, const int64_t plain_len, char *key, const int32_t key_len, std::string enc_mode); +py::bytes PyDecrypt(std::string encrypt_data_path, char *key, const int32_t key_len, std::string dec_mode); +bool PyIsCipherFile(std::string file_path); +} // namespace crypto +} // namespace mindspore +#endif diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc index 3c183b65a5a..5f4f0523299 100644 --- a/mindspore/ccsrc/pipeline/jit/init.cc +++ b/mindspore/ccsrc/pipeline/jit/init.cc @@ -28,6 +28,7 @@ #include "utils/mpi/mpi_config.h" #include "frontend/parallel/context.h" #include "frontend/parallel/costmodel_context.h" +#include "crypto/crypto_pybind.h" #ifdef ENABLE_GPU_COLLECTIVE #include "runtime/device/gpu/distribution/collective_init.h" #else @@ -330,4 +331,8 @@ PYBIND11_MODULE(_c_expression, m) { (void)py::class_>(m, "OpInfoLoaderPy") .def(py::init()) .def("get_all_ops_info", &OpInfoLoaderPy::GetAllOpsInfo, "get all ops info."); + + (void)m.def("_encrypt", &mindspore::crypto::PyEncrypt, "Encrypt the data."); + (void)m.def("_decrypt", &mindspore::crypto::PyDecrypt, "Decrypt the data."); + (void)m.def("_is_cipher_file", &mindspore::crypto::PyIsCipherFile, "Determine whether the file is encrypted"); } diff --git a/mindspore/train/callback/_checkpoint.py b/mindspore/train/callback/_checkpoint.py index bb79c75a6d0..70bb83c6461 100644 --- a/mindspore/train/callback/_checkpoint.py +++ b/mindspore/train/callback/_checkpoint.py @@ -82,6 +82,10 @@ class CheckpointConfig: async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False. saved_network (Cell): Network to be saved in checkpoint file. If the saved_network has no relation with the network in training, the initial value of saved_network will be saved. Default: None. + enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is None, the encryption + is not required. Default: None. + enc_mode (str): This parameter is valid only when enc_key is not set to None. Specifies the encryption + mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'. Raises: ValueError: If the input_param is None or 0. @@ -126,7 +130,9 @@ class CheckpointConfig: keep_checkpoint_per_n_minutes=0, integrated_save=True, async_save=False, - saved_network=None): + saved_network=None, + enc_key=None, + enc_mode='AES-GCM'): if save_checkpoint_steps is not None: save_checkpoint_steps = Validator.check_non_negative_int(save_checkpoint_steps) @@ -160,6 +166,8 @@ class CheckpointConfig: self._integrated_save = Validator.check_bool(integrated_save) self._async_save = Validator.check_bool(async_save) self._saved_network = saved_network + self._enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes)) + self._enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str) @property def save_checkpoint_steps(self): @@ -196,6 +204,16 @@ class CheckpointConfig: """Get the value of _saved_network""" return self._saved_network + @property + def enc_key(self): + """Get the value of _enc_key""" + return self._enc_key + + @property + def enc_mode(self): + """Get the value of _enc_mode""" + return self._enc_mode + def get_checkpoint_policy(self): """Get the policy of checkpoint.""" checkpoint_policy = {'save_checkpoint_steps': self.save_checkpoint_steps, @@ -355,7 +373,7 @@ class ModelCheckpoint(Callback): network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network save_checkpoint(network, cur_file, self._config.integrated_save, - self._config.async_save) + self._config.async_save, self._config.enc_key, self._config.enc_mode) self._latest_ckpt_file_name = cur_file diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 46039952ee7..8d61c1ae2eb 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -14,6 +14,7 @@ # ============================================================================ """Model and parameters serialization.""" import os + import sys import stat import math @@ -40,7 +41,7 @@ from mindspore._checkparam import check_input_data, Validator from mindspore.compression.export import quant_export from mindspore.parallel._tensor import _load_tensor from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices -from .._c_expression import load_mindir +from .._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16, @@ -120,14 +121,19 @@ def _update_param(param, new_param): param.set_data(type(param.data)(new_param.data)) -def _exec_save(ckpt_file_name, data_list): +def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"): """Execute the process of saving checkpoint into file.""" try: + MAX_BLOCK_SIZE = 1024*1024*512 with _ckpt_mutex: if os.path.exists(ckpt_file_name): os.remove(ckpt_file_name) with open(ckpt_file_name, "ab") as f: + if enc_key is not None: + plain_data = bytes(0) + cipher_data = bytes(0) + for name, value in data_list.items(): data_size = value[2].nbytes / 1024 if data_size > SLICE_SIZE: @@ -145,7 +151,19 @@ def _exec_save(ckpt_file_name, data_list): param_tensor.tensor_type = value[1] param_tensor.tensor_content = param_slice.tobytes() - f.write(checkpoint_list.SerializeToString()) + if enc_key is None: + f.write(checkpoint_list.SerializeToString()) + else: + plain_data += checkpoint_list.SerializeToString() + while len(plain_data) >= MAX_BLOCK_SIZE: + cipher_data += _encrypt(plain_data[0: MAX_BLOCK_SIZE], MAX_BLOCK_SIZE, enc_key, + len(enc_key), enc_mode) + plain_data = plain_data[MAX_BLOCK_SIZE:] + + if enc_key is not None: + if plain_data: + cipher_data += _encrypt(plain_data, len(plain_data), enc_key, len(enc_key), enc_mode) + f.write(cipher_data) os.chmod(ckpt_file_name, stat.S_IRUSR) @@ -154,7 +172,7 @@ def _exec_save(ckpt_file_name, data_list): raise e -def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False): +def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False, enc_key=None, enc_mode="AES-GCM"): """ Saves checkpoint info to a specified file. @@ -166,6 +184,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten. integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: True async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False + enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is None, the encryption + is not required. Default: None. + enc_mode (str): This parameter is valid only when enc_key is not set to None. Specifies the encryption + mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'. Raises: TypeError: If the parameter save_obj is not `nn.Cell` or list type. And if the parameter @@ -176,6 +198,8 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F raise TypeError("The parameter save_obj should be nn.Cell or list, but got {}".format(type(save_obj))) integrated_save = Validator.check_bool(integrated_save) async_save = Validator.check_bool(async_save) + enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes)) + enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str) logger.info("Execute the process of saving checkpoint files.") @@ -218,10 +242,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F data_list[key].append(data) if async_save: - thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list), name="asyn_save_ckpt") + thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list, enc_key, enc_mode), name="asyn_save_ckpt") thr.start() else: - _exec_save(ckpt_file_name, data_list) + _exec_save(ckpt_file_name, data_list, enc_key, enc_mode) logger.info("Saving checkpoint process is finished.") @@ -278,7 +302,7 @@ def load(file_name): return graph -def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None): +def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None, dec_mode="AES-GCM"): """ Loads checkpoint info from a specified file. @@ -289,6 +313,10 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N in the param_dict into net with the same suffix. Default: False filter_prefix (Union[str, list[str], tuple[str]]): Parameters starting with the filter_prefix will not be loaded. Default: None. + dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is None, the decryption + is not required. Default: None. + dec_mode (str): This parameter is valid only when dec_key is not set to None. Specifies the decryption + mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'. Returns: Dict, key is parameter name, value is a Parameter. @@ -303,15 +331,25 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N >>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1") """ ckpt_file_name, filter_prefix = _check_checkpoint_param(ckpt_file_name, filter_prefix) + dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes)) + dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str) logger.info("Execute the process of loading checkpoint files.") checkpoint_list = Checkpoint() try: - with open(ckpt_file_name, "rb") as f: - pb_content = f.read() + if dec_key is None: + with open(ckpt_file_name, "rb") as f: + pb_content = f.read() + else: + pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode) checkpoint_list.ParseFromString(pb_content) except BaseException as e: - logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", ckpt_file_name) + if _is_cipher_file(ckpt_file_name): + logger.error("Failed to read the checkpoint file `%s`. The file may be encrypted, please pass in the " + "dec_key.", ckpt_file_name) + else: + logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", \ + ckpt_file_name) raise ValueError(e.__str__()) parameter_dict = {} @@ -1075,7 +1113,7 @@ def merge_sliced_parameter(sliced_parameters, strategy=None): return merged_parameter -def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None): +def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None, dec_key=None, dec_mode='AES-GCM'): """ Load checkpoint into net for distributed predication. @@ -1088,6 +1126,10 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy= elements are [dev_matrix, tensor_map, param_split_shape, field]. If None, it means that the predication process just uses single device. Default: None. + dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is None, the decryption + is not required. Default: None. + dec_mode (str): This parameter is valid only when dec_key is not set to None. Specifies the decryption + mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'. Raises: TypeError: The type of inputs do not match the requirements. @@ -1106,6 +1148,9 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy= f"dev_matrix (list[int]), tensor_map (list[int]), " f"param_split_shape (list[int]) and field_size (zero).") + dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes)) + dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str) + train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file") _train_strategy = build_searched_strategy(train_strategy_filename) train_strategy = _convert_to_list(_train_strategy) @@ -1128,7 +1173,7 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy= param_rank = rank_list[param.name][0] skip_merge_split = rank_list[param.name][1] for rank in param_rank: - sliced_param = load_checkpoint(checkpoint_filenames[rank])[param.name] + sliced_param = load_checkpoint(checkpoint_filenames[rank], dec_key=dec_key, dec_mode=dec_mode)[param.name] sliced_params.append(sliced_param) if skip_merge_split: split_param = sliced_params[0] diff --git a/tests/ut/python/utils/test_callback.py b/tests/ut/python/utils/test_callback.py index 4a90d962e42..d5d94e6ae0e 100644 --- a/tests/ut/python/utils/test_callback.py +++ b/tests/ut/python/utils/test_callback.py @@ -14,6 +14,7 @@ # ============================================================================ """test callback function.""" import os +import platform import stat from unittest import mock @@ -246,6 +247,43 @@ def test_checkpoint_save_ckpt_seconds(): ckpt_cb2.step_end(run_context) +def test_checkpoint_save_ckpt_with_encryption(): + """Test checkpoint save ckpt with encryption.""" + train_config = CheckpointConfig( + save_checkpoint_steps=16, + save_checkpoint_seconds=0, + keep_checkpoint_max=5, + keep_checkpoint_per_n_minutes=0, + enc_key=os.urandom(16), + enc_mode="AES-GCM") + ckpt_cb = ModelCheckpoint(config=train_config) + cb_params = _InternalCallbackParam() + net = Net() + loss = nn.SoftmaxCrossEntropyWithLogits() + optim = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + network_ = WithLossCell(net, loss) + _train_network = TrainOneStepCell(network_, optim) + cb_params.train_network = _train_network + cb_params.epoch_num = 10 + cb_params.cur_epoch_num = 5 + cb_params.cur_step_num = 160 + cb_params.batch_num = 32 + run_context = RunContext(cb_params) + ckpt_cb.begin(run_context) + ckpt_cb.step_end(run_context) + ckpt_cb2 = ModelCheckpoint(config=train_config) + cb_params.cur_epoch_num = 1 + cb_params.cur_step_num = 15 + + if platform.system().lower() == "windows": + with pytest.raises(NotImplementedError): + ckpt_cb2.begin(run_context) + ckpt_cb2.step_end(run_context) + else: + ckpt_cb2.begin(run_context) + ckpt_cb2.step_end(run_context) + + def test_CallbackManager(): """TestCallbackManager.""" ck_obj = ModelCheckpoint() diff --git a/tests/ut/python/utils/test_serialize.py b/tests/ut/python/utils/test_serialize.py index bb8a21fd9b5..4b4fe44aa64 100644 --- a/tests/ut/python/utils/test_serialize.py +++ b/tests/ut/python/utils/test_serialize.py @@ -14,6 +14,7 @@ # ============================================================================ """ut for model serialize(save/load)""" import os +import platform import stat import time @@ -299,6 +300,30 @@ def test_load_checkpoint_empty_file(): load_checkpoint("empty.ckpt") +def test_save_and_load_checkpoint_for_network_with_encryption(): + """ test save and checkpoint for network with encryption""" + net = Net() + loss = SoftmaxCrossEntropyWithLogits(sparse=True) + opt = Momentum(net.trainable_params(), 0.0, 0.9, 0.0001, 1024) + + loss_net = WithLossCell(net, loss) + train_network = TrainOneStepCell(loss_net, opt) + key = os.urandom(16) + mode = "AES-GCM" + ckpt_path = "./encrypt_ckpt.ckpt" + if platform.system().lower() == "windows": + with pytest.raises(NotImplementedError): + save_checkpoint(train_network, ckpt_file_name=ckpt_path, enc_key=key, enc_mode=mode) + param_dict = load_checkpoint(ckpt_path, dec_key=key, dec_mode="AES-GCM") + load_param_into_net(net, param_dict) + else: + save_checkpoint(train_network, ckpt_file_name=ckpt_path, enc_key=key, enc_mode=mode) + param_dict = load_checkpoint(ckpt_path, dec_key=key, dec_mode="AES-GCM") + load_param_into_net(net, param_dict) + if os.path.exists(ckpt_path): + os.remove(ckpt_path) + + class MYNET(nn.Cell): """ NET definition """