Add encryption support for mindir

This commit is contained in:
liuluobin 2021-06-10 14:20:36 +08:00
parent 941835dcf7
commit 5b9b46224b
25 changed files with 759 additions and 513 deletions

View File

@ -27,21 +27,47 @@
#include "include/api/dual_abi_helper.h"
namespace mindspore {
using Key = struct Key {
const size_t max_key_len = 32;
size_t len;
unsigned char key[32];
Key(): len(0) {}
};
class MS_API Serialization {
public:
static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph);
inline static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
const Key &dec_key, const std::string &dec_mode);
inline static Status Load(const std::string &file, ModelType model_type, Graph *graph);
inline static Status Load(const std::string &file, ModelType model_type, Graph *graph, const Key &dec_key,
const std::string &dec_mode);
static Status LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters);
static Status SetParameters(const std::map<std::string, Buffer> &parameters, Model *model);
static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);
static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file);
private:
static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
const Key &dec_key, const std::vector<char> &dec_mode);
static Status Load(const std::vector<char> &file, ModelType model_type, Graph *graph);
static Status Load(const std::vector<char> &file, ModelType model_type, Graph *graph, const Key &dec_key,
const std::vector<char> &dec_mode);
};
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
const Key &dec_key, const std::string &dec_mode) {
return Load(model_data, data_size, model_type, graph, dec_key, StringToChar(dec_mode));
}
Status Serialization::Load(const std::string &file, ModelType model_type, Graph *graph) {
return Load(StringToChar(file), model_type, graph);
}
Status Serialization::Load(const std::string &file, ModelType model_type, Graph *graph, const Key &dec_key,
const std::string &dec_mode) {
return Load(StringToChar(file), model_type, graph, dec_key, StringToChar(dec_mode));
}
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_SERIALIZATION_H

View File

@ -218,7 +218,6 @@ set(SUB_COMP
pipeline/jit
pipeline/pynative
common debug pybind_api utils vm profiler ps
crypto
)
foreach(_comp ${SUB_COMP})

View File

@ -1,6 +0,0 @@
file(GLOB_RECURSE _CRYPTO_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
add_library(_mindspore_crypto_obj OBJECT ${_CRYPTO_SRC_FILES})
if(${CMAKE_SYSTEM_NAME} MATCHES "Linux")
target_link_libraries(_mindspore_crypto_obj mindspore::crypto)
endif()

View File

@ -1,347 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "crypto/crypto.h"
namespace mindspore {
namespace crypto {
int64_t Min(int64_t a, int64_t b) { return a < b ? a : b; }
void *intToByte(Byte *byte, const int32_t &n) {
memset(byte, 0, sizeof(Byte) * 4);
byte[0] = (Byte)(0xFF & n);
byte[1] = (Byte)((0xFF00 & n) >> 8);
byte[2] = (Byte)((0xFF0000 & n) >> 16);
byte[3] = (Byte)((0xFF000000 & n) >> 24);
return byte;
}
int32_t ByteToint(const Byte *byteArray) {
int32_t res = byteArray[0] & 0xFF;
res |= ((byteArray[1] << 8) & 0xFF00);
res |= ((byteArray[2] << 16) & 0xFF0000);
res += ((byteArray[3] << 24) & 0xFF000000);
return res;
}
bool IsCipherFile(std::string file_path) {
std::ifstream fid(file_path, std::ios::in | std::ios::binary);
if (!fid) {
return false;
}
char *int_buf = new char[4];
fid.read(int_buf, sizeof(int32_t));
fid.close();
auto flag = ByteToint(reinterpret_cast<Byte *>(int_buf));
delete[] int_buf;
return flag == MAGIC_NUM;
}
#if defined(_WIN32)
Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, Byte *key, const int32_t key_len,
const std::string &enc_mode) {
MS_EXCEPTION(NotSupportError) << "Unsupported feature in Windows platform.";
}
Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *key, const int32_t key_len,
const std::string &dec_mode) {
MS_EXCEPTION(NotSupportError) << "Unsupported feature in Windows platform.";
}
#else
bool ParseEncryptData(const Byte *encrypt_data, const int32_t encrypt_len, Byte **iv, int32_t *iv_len,
Byte **cipher_data, int32_t *cipher_len) {
// encrypt_data is organized in order to iv_len, iv, cipher_len, cipher_data
Byte buf[4];
memcpy_s(buf, 4, encrypt_data, 4);
*iv_len = ByteToint(buf);
memcpy_s(buf, 4, encrypt_data + *iv_len + 4, 4);
*cipher_len = ByteToint(buf);
if (*iv_len <= 0 || *cipher_len <= 0 || *iv_len + *cipher_len + 8 != encrypt_len) {
MS_LOG(ERROR) << "Failed to parse encrypt data.";
return false;
}
*iv = new Byte[*iv_len];
memcpy_s(*iv, *iv_len, encrypt_data + 4, *iv_len);
*cipher_data = new Byte[*cipher_len];
memcpy_s(*cipher_data, *cipher_len, encrypt_data + *iv_len + 8, *cipher_len);
return true;
}
bool ParseMode(std::string mode, std::string *alg_mode, std::string *work_mode) {
std::smatch results;
std::regex re("([A-Z]{3})-([A-Z]{3})");
if (!std::regex_match(mode.c_str(), re)) {
MS_LOG(ERROR) << "Mode " << mode << " is invalid.";
return false;
}
std::regex_search(mode, results, re);
*alg_mode = results[1];
*work_mode = results[2];
return true;
}
EVP_CIPHER_CTX *GetEVP_CIPHER_CTX(const std::string &work_mode, const Byte *key, const int32_t key_len, const Byte *iv,
int flag) {
int ret = 0;
EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
if (work_mode != "GCM" && work_mode != "CBC") {
MS_LOG(ERROR) << "Work mode " << work_mode << " is invalid.";
return nullptr;
}
const EVP_CIPHER *(*funcPtr)() = nullptr;
if (work_mode == "GCM") {
switch (key_len) {
case 16:
funcPtr = EVP_aes_128_gcm;
break;
case 24:
funcPtr = EVP_aes_192_gcm;
break;
case 32:
funcPtr = EVP_aes_256_gcm;
break;
default:
MS_EXCEPTION(ValueError) << "The key length must be 16, 24 or 32, but got key length is " << key_len << ".";
}
} else if (work_mode == "CBC") {
switch (key_len) {
case 16:
funcPtr = EVP_aes_128_cbc;
break;
case 24:
funcPtr = EVP_aes_192_cbc;
break;
case 32:
funcPtr = EVP_aes_256_cbc;
break;
default:
MS_EXCEPTION(ValueError) << "The key length must be 16, 24 or 32, but got key length is " << key_len << ".";
}
}
if (flag == 0) {
ret = EVP_EncryptInit_ex(ctx, funcPtr(), NULL, key, iv);
} else if (flag == 1) {
ret = EVP_DecryptInit_ex(ctx, funcPtr(), NULL, key, iv);
}
if (ret != 1) {
MS_LOG(ERROR) << "EVP_EncryptInit_ex failed";
return nullptr;
}
if (work_mode == "CBC") EVP_CIPHER_CTX_set_padding(ctx, 1);
return ctx;
}
bool _BlockEncrypt(Byte *encrypt_data, const int64_t encrypt_data_buf_len, int64_t *encrypt_data_len, Byte *plain_data,
const int64_t plain_len, Byte *key, const int32_t key_len, const std::string &enc_mode) {
int32_t cipher_len = 0;
int32_t iv_len = AES_BLOCK_SIZE;
Byte *iv = new Byte[iv_len];
RAND_bytes(iv, sizeof(Byte) * iv_len);
Byte *iv_cpy = new Byte[16];
memcpy_s(iv_cpy, 16, iv, 16);
int32_t ret = 0;
int32_t flen = 0;
std::string alg_mode;
std::string work_mode;
if (!ParseMode(enc_mode, &alg_mode, &work_mode)) {
return false;
}
auto ctx = GetEVP_CIPHER_CTX(work_mode, key, key_len, iv, 0);
if (ctx == nullptr) {
MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX.";
return false;
}
Byte *cipher_data;
cipher_data = new Byte[plain_len + 16];
ret = EVP_EncryptUpdate(ctx, cipher_data, &cipher_len, plain_data, plain_len);
if (ret != 1) {
MS_LOG(ERROR) << "EVP_EncryptUpdate failed";
delete[] cipher_data;
return false;
}
if (work_mode == "CBC") {
EVP_EncryptFinal_ex(ctx, cipher_data + cipher_len, &flen);
cipher_len += flen;
}
EVP_CIPHER_CTX_free(ctx);
int64_t cur = 0;
*encrypt_data_len = sizeof(int32_t) * 2 + iv_len + cipher_len;
Byte *byte_buf = new Byte[4];
intToByte(byte_buf, *encrypt_data_len);
memcpy_s(encrypt_data + cur, encrypt_data_buf_len - cur, byte_buf, 4);
cur += 4;
intToByte(byte_buf, iv_len);
memcpy_s(encrypt_data + cur, encrypt_data_buf_len - cur, byte_buf, 4);
cur += 4;
memcpy_s(encrypt_data + cur, encrypt_data_buf_len - cur, iv_cpy, iv_len);
cur += iv_len;
intToByte(byte_buf, cipher_len);
memcpy_s(encrypt_data + cur, encrypt_data_buf_len - cur, byte_buf, 4);
cur += 4;
memcpy_s(encrypt_data + cur, encrypt_data_buf_len - cur, cipher_data, cipher_len);
*encrypt_data_len += 4;
delete[] byte_buf;
delete[] cipher_data;
return true;
}
bool _BlockDecrypt(Byte *plain_data, int32_t *plain_len, Byte *encrypt_data, const int64_t encrypt_len, Byte *key,
const int32_t key_len, const std::string &dec_mode) {
std::string alg_mode;
std::string work_mode;
if (!ParseMode(dec_mode, &alg_mode, &work_mode)) {
return false;
}
int32_t iv_len = 0;
int32_t cipher_len = 0;
Byte *iv = NULL;
Byte *cipher_data = NULL;
if (!ParseEncryptData(encrypt_data, encrypt_len, &iv, &iv_len, &cipher_data, &cipher_len)) {
return false;
}
int ret = 0;
int mlen = 0;
auto ctx = GetEVP_CIPHER_CTX(work_mode, key, key_len, iv, 1);
if (ctx == nullptr) {
MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX.";
return false;
}
ret = EVP_DecryptUpdate(ctx, plain_data, plain_len, cipher_data, cipher_len);
if (ret != 1) {
MS_LOG(ERROR) << "EVP_DecryptUpdate failed";
return false;
}
if (work_mode == "CBC") {
ret = EVP_DecryptFinal_ex(ctx, plain_data + *plain_len, &mlen);
if (ret != 1) {
MS_LOG(ERROR) << "EVP_DecryptFinal_ex failed";
return false;
}
*plain_len += mlen;
}
delete[] iv;
delete[] cipher_data;
EVP_CIPHER_CTX_free(ctx);
return true;
}
Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, Byte *key, const int32_t key_len,
const std::string &enc_mode) {
int64_t cur_pos = 0;
int64_t block_enc_len = 0;
int64_t encrypt_buf_len = plain_len + (plain_len / MAX_BLOCK_SIZE + 1) * 100;
int64_t block_enc_buf_len = MAX_BLOCK_SIZE + 100;
Byte *byte_buf = new Byte[4];
Byte *encrypt_data = new Byte[encrypt_buf_len];
Byte *block_buf = new Byte[MAX_BLOCK_SIZE];
Byte *block_enc_buf = new Byte[block_enc_buf_len];
*encrypt_len = 0;
while (cur_pos < plain_len) {
int64_t cur_block_size = Min(MAX_BLOCK_SIZE, plain_len - cur_pos);
memcpy_s(block_buf, MAX_BLOCK_SIZE, plain_data + cur_pos, cur_block_size);
if (!_BlockEncrypt(block_enc_buf, block_enc_buf_len, &block_enc_len, block_buf, cur_block_size, key, key_len,
enc_mode)) {
delete[] byte_buf;
delete[] block_buf;
delete[] block_enc_buf;
delete[] encrypt_data;
MS_EXCEPTION(ValueError) << "Failed to encrypt data, please check if enc_key or enc_mode is valid.";
}
intToByte(byte_buf, MAGIC_NUM);
memcpy_s(encrypt_data + *encrypt_len, encrypt_buf_len - *encrypt_len, byte_buf, sizeof(int32_t));
*encrypt_len += sizeof(int32_t);
memcpy_s(encrypt_data + *encrypt_len, encrypt_buf_len - *encrypt_len, block_enc_buf, block_enc_len);
*encrypt_len += block_enc_len;
cur_pos += cur_block_size;
}
delete[] byte_buf;
delete[] block_buf;
delete[] block_enc_buf;
return encrypt_data;
}
Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *key, const int32_t key_len,
const std::string &dec_mode) {
char *block_buf = new char[MAX_BLOCK_SIZE * 2];
char *int_buf = new char[4];
Byte *decrypt_block_buf = new Byte[MAX_BLOCK_SIZE * 2];
int32_t decrypt_block_len;
std::ifstream fid(encrypt_data_path, std::ios::in | std::ios::binary);
if (!fid) {
MS_EXCEPTION(ValueError) << "Open file '" << encrypt_data_path << "' failed, please check the correct of the file.";
}
fid.seekg(0, std::ios_base::end);
int64_t file_size = fid.tellg();
fid.clear();
fid.seekg(0);
Byte *decrypt_data = new Byte[file_size];
*decrypt_len = 0;
while (fid.tellg() < file_size) {
fid.read(int_buf, sizeof(int32_t));
int cipher_flag = ByteToint(reinterpret_cast<Byte *>(int_buf));
if (cipher_flag != MAGIC_NUM) {
delete[] block_buf;
delete[] int_buf;
delete[] decrypt_block_buf;
delete[] decrypt_data;
MS_EXCEPTION(ValueError) << "File \"" << encrypt_data_path
<< "\"is not an encrypted file and cannot be decrypted";
}
fid.read(int_buf, sizeof(int32_t));
int32_t block_size = ByteToint(reinterpret_cast<Byte *>(int_buf));
fid.read(block_buf, sizeof(char) * block_size);
if (!(_BlockDecrypt(decrypt_block_buf, &decrypt_block_len, reinterpret_cast<Byte *>(block_buf), block_size, key,
key_len, dec_mode))) {
delete[] block_buf;
delete[] int_buf;
delete[] decrypt_block_buf;
delete[] decrypt_data;
MS_EXCEPTION(ValueError) << "Failed to decrypt data, please check if dec_key or dec_mode is valid";
}
auto destMax = Min(file_size - *decrypt_len, SECUREC_MEM_MAX_LEN);
memcpy_s(decrypt_data + *decrypt_len, destMax, decrypt_block_buf, decrypt_block_len);
*decrypt_len += decrypt_block_len;
}
fid.close();
delete[] block_buf;
delete[] int_buf;
delete[] decrypt_block_buf;
return decrypt_data;
}
#endif
} // namespace crypto
} // namespace mindspore

View File

@ -1,46 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_CRYPTO_CRYPTO_H
#define MINDSPORE_CCSRC_CRYPTO_CRYPTO_H
#if not defined(_WIN32)
#include <openssl/aes.h>
#include <openssl/evp.h>
#include <openssl/rand.h>
#endif
#include <stdio.h>
#include <fstream>
#include <string>
#include <regex>
#include "utils/log_adapter.h"
typedef unsigned char Byte;
namespace mindspore {
namespace crypto {
const int MAX_BLOCK_SIZE = 512 * 1024 * 1024; // Maximum ciphertext segment, units is Byte
const unsigned int MAGIC_NUM = 0x7F3A5ED8; // Magic number
Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, Byte *key, const int32_t key_len,
const std::string &enc_mode);
Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *key, const int32_t key_len,
const std::string &dec_mode);
bool IsCipherFile(const std::string file_path);
} // namespace crypto
} // namespace mindspore
#endif

View File

@ -1,41 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "crypto/crypto_pybind.h"
namespace mindspore {
namespace crypto {
py::bytes PyEncrypt(char *plain_data, const int64_t plain_len, char *key, const int32_t key_len, std::string enc_mode) {
int64_t encrypt_len;
char *encrypt_data;
encrypt_data = reinterpret_cast<char *>(Encrypt(&encrypt_len, reinterpret_cast<Byte *>(plain_data), plain_len,
reinterpret_cast<Byte *>(key), key_len, enc_mode));
auto py_encrypt_data = py::bytes(encrypt_data, encrypt_len);
delete[] encrypt_data;
return py_encrypt_data;
}
py::bytes PyDecrypt(std::string encrypt_data_path, char *key, const int32_t key_len, std::string dec_mode) {
int64_t decrypt_len;
char *decrypt_data;
decrypt_data = reinterpret_cast<char *>(
Decrypt(&decrypt_len, encrypt_data_path, reinterpret_cast<Byte *>(key), key_len, dec_mode));
auto py_decrypt_data = py::bytes(decrypt_data, decrypt_len);
delete[] decrypt_data;
return py_decrypt_data;
}
bool PyIsCipherFile(std::string file_path) { return IsCipherFile(file_path); }
} // namespace crypto
} // namespace mindspore

View File

@ -1,32 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_CRYPTO_CRYPTO_PYBIND_H
#define MINDSPORE_CCSRC_CRYPTO_CRYPTO_PYBIND_H
#include "crypto/crypto.h"
#include <pybind11/pybind11.h>
#include <string>
namespace py = pybind11;
namespace mindspore {
namespace crypto {
py::bytes PyEncrypt(char *plain_data, const int64_t plain_len, char *key, const int32_t key_len, std::string enc_mode);
py::bytes PyDecrypt(std::string encrypt_data_path, char *key, const int32_t key_len, std::string dec_mode);
bool PyIsCipherFile(std::string file_path);
} // namespace crypto
} // namespace mindspore
#endif

View File

@ -18,6 +18,7 @@
#include "cxx_api/graph/graph_data.h"
#include "utils/log_adapter.h"
#include "mindspore/core/load_mindir/load_model.h"
#include "utils/crypto.h"
namespace mindspore {
static Buffer ReadFile(const std::string &file) {
@ -78,7 +79,47 @@ Status Serialization::Load(const void *model_data, size_t data_size, ModelType m
try {
anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(model_data), data_size);
} catch (const std::exception &) {
MS_LOG(ERROR) << "Load model failed.";
if (IsCipherFile(reinterpret_cast<const unsigned char *>(model_data))) {
MS_LOG(ERROR) << "Load model failed. The model_data may be encrypted, please pass in correct key.";
} else {
MS_LOG(ERROR) << "Load model failed.";
}
return kMEInvalidInput;
}
*graph = Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
return kSuccess;
} else if (model_type == kOM) {
*graph = Graph(std::make_shared<Graph::GraphData>(Buffer(model_data, data_size), kOM));
return kSuccess;
}
MS_LOG(ERROR) << "Unsupported ModelType " << model_type;
return kMEInvalidInput;
}
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
const Key &dec_key, const std::vector<char> &dec_mode) {
if (graph == nullptr) {
MS_LOG(ERROR) << "Output args graph is nullptr.";
return kMEInvalidInput;
}
if (model_type == kMindIR) {
FuncGraphPtr anf_graph = nullptr;
try {
if (dec_key.len > dec_key.max_key_len) {
MS_LOG(ERROR) << "The key length exceeds maximum length: 32.";
return kMEInvalidInput;
} else {
size_t plain_data_size;
std::string dec_mode_str(dec_mode.begin(), dec_mode.end());
auto plain_data = mindspore::Decrypt(&plain_data_size, reinterpret_cast<const unsigned char *>(model_data),
data_size, dec_key.key, dec_key.len, dec_mode_str);
anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(plain_data.get()), plain_data_size);
}
} catch (const std::exception &) {
MS_LOG(ERROR) << "Load model failed. Please check the valid of dec_key and dec_mode.";
return kMEInvalidInput;
}
@ -103,7 +144,48 @@ Status Serialization::Load(const std::vector<char> &file, ModelType model_type,
if (model_type == kMindIR) {
FuncGraphPtr anf_graph = LoadMindIR(file_path);
if (anf_graph == nullptr) {
MS_LOG(ERROR) << "Load model failed.";
if (IsCipherFile(file_path)) {
MS_LOG(ERROR) << "Load model failed. The file may be encrypted, please pass in correct key.";
} else {
MS_LOG(ERROR) << "Load model failed.";
}
return kMEInvalidInput;
}
*graph = Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
return kSuccess;
} else if (model_type == kOM) {
Buffer data = ReadFile(file_path);
if (data.Data() == nullptr) {
MS_LOG(ERROR) << "Read file " << file_path << " failed.";
return kMEInvalidInput;
}
*graph = Graph(std::make_shared<Graph::GraphData>(data, kOM));
return kSuccess;
}
MS_LOG(ERROR) << "Unsupported ModelType " << model_type;
return kMEInvalidInput;
}
Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph, const Key &dec_key,
const std::vector<char> &dec_mode) {
if (graph == nullptr) {
MS_LOG(ERROR) << "Output args graph is nullptr.";
return kMEInvalidInput;
}
std::string file_path = CharToString(file);
if (model_type == kMindIR) {
FuncGraphPtr anf_graph;
if (dec_key.len > dec_key.max_key_len) {
MS_LOG(ERROR) << "The key length exceeds maximum length: 32.";
return kMEInvalidInput;
} else {
std::string dec_mode_str(dec_mode.begin(), dec_mode.end());
anf_graph = LoadMindIR(file_path, false, dec_key.key, dec_key.len, dec_mode_str);
}
if (anf_graph == nullptr) {
MS_LOG(ERROR) << "Load model failed. Please check the valid of dec_key and dec_mode";
return kMEInvalidInput;
}
*graph = Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));

View File

@ -28,7 +28,6 @@
#include "utils/mpi/mpi_config.h"
#include "frontend/parallel/context.h"
#include "frontend/parallel/costmodel_context.h"
#include "crypto/crypto_pybind.h"
#ifdef ENABLE_GPU_COLLECTIVE
#include "runtime/device/gpu/distribution/collective_init.h"
#else
@ -112,7 +111,9 @@ PYBIND11_MODULE(_c_expression, m) {
(void)m.def("init_pipeline", &mindspore::pipeline::InitPipeline, "Init Pipeline.");
(void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph.");
(py::object) m.def("load_mindir", &mindspore::pipeline::LoadMindIR, py::arg("file_name"), "Load model as Graph.");
(py::object)
m.def("load_mindir", &mindspore::pipeline::LoadMindIR, py::arg("file_name"), py::arg("dec_key") = nullptr,
py::arg("key_len") = py::int_(0), py::arg("dec_mode") = py::str("AES-GCM"), "Load model as Graph.");
(void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
.def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")
@ -367,7 +368,7 @@ PYBIND11_MODULE(_c_expression, m) {
.def(py::init())
.def("get_all_ops_info", &OpInfoLoaderPy::GetAllOpsInfo, "get all ops info.");
(void)m.def("_encrypt", &mindspore::crypto::PyEncrypt, "Encrypt the data.");
(void)m.def("_decrypt", &mindspore::crypto::PyDecrypt, "Decrypt the data.");
(void)m.def("_is_cipher_file", &mindspore::crypto::PyIsCipherFile, "Determine whether the file is encrypted");
(void)m.def("_encrypt", &mindspore::pipeline::PyEncrypt, "Encrypt the data.");
(void)m.def("_decrypt", &mindspore::pipeline::PyDecrypt, "Decrypt the data.");
(void)m.def("_is_cipher_file", &mindspore::pipeline::PyIsCipherFile, "Determine whether the file is encrypted");
}

View File

@ -18,6 +18,7 @@
#include "pipeline/jit/pipeline.h"
#include <memory>
#include <sstream>
#include <map>
#include <unordered_map>
@ -52,6 +53,7 @@
#include "pipeline/jit/prim_bprop_optimizer.h"
#include "runtime/hardware/device_context_manager.h"
#include "runtime/framework/actor/actor_common.h"
#include "utils/crypto.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "ps/constants.h"
@ -189,7 +191,7 @@ void GetCachedFuncGraph(const ResourcePtr &resource) {
return;
}
MS_LOG(INFO) << "Use the compilation cache \"" << realpath.value() << "\" and execute the backend actions only.";
FuncGraphPtr fg = LoadMindIR(realpath.value());
FuncGraphPtr fg = mindspore::LoadMindIR(realpath.value());
if (fg == nullptr) {
MS_LOG(EXCEPTION) << "Failed to load the compilation cache file: " << realpath.value();
}
@ -1185,7 +1187,10 @@ void ExportGraph(const std::string &file_name, const std::string &, const std::s
#endif
}
FuncGraphPtr LoadMindIR(const std::string &file_name) { return mindspore::LoadMindIR(file_name); }
FuncGraphPtr LoadMindIR(const std::string &file_name, char *dec_key, const size_t key_len,
const std::string &dec_mode) {
return mindspore::LoadMindIR(file_name, false, reinterpret_cast<unsigned char *>(dec_key), key_len, dec_mode);
}
void ReleaseGeTsd() {
auto context_ptr = MsContext::GetInstance();
@ -1282,5 +1287,24 @@ void ClearResAtexit() {
parse::CleanDataClassToClassMap();
trace::ClearTraceStack();
}
py::bytes PyEncrypt(char *plain_data, const size_t plain_len, char *key, const size_t key_len, std::string enc_mode) {
size_t encrypt_len;
auto encrypt_data = mindspore::Encrypt(&encrypt_len, reinterpret_cast<Byte *>(plain_data), plain_len,
reinterpret_cast<Byte *>(key), key_len, enc_mode);
auto py_encrypt_data = py::bytes(reinterpret_cast<char *>(encrypt_data.get()), encrypt_len);
return py_encrypt_data;
}
py::bytes PyDecrypt(std::string encrypt_data_path, char *key, const size_t key_len, std::string dec_mode) {
size_t decrypt_len;
auto decrypt_data =
mindspore::Decrypt(&decrypt_len, encrypt_data_path, reinterpret_cast<Byte *>(key), key_len, dec_mode);
auto py_decrypt_data = py::bytes(reinterpret_cast<char *>(decrypt_data.get()), decrypt_len);
return py_decrypt_data;
}
bool PyIsCipherFile(const std::string &file_path) { return mindspore::IsCipherFile(file_path); }
} // namespace pipeline
} // namespace mindspore

View File

@ -143,7 +143,7 @@ void ClearResAtexit();
void ReleaseGeTsd();
void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase);
FuncGraphPtr LoadMindIR(const std::string &file_name);
FuncGraphPtr LoadMindIR(const std::string &file_name, char *dec_key, const size_t key_len, const std::string &dec_mode);
// init and exec dataset sub graph
bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size,
@ -157,6 +157,10 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list);
py::bytes PyEncrypt(char *plain_data, const size_t plain_len, char *key, const size_t key_len, std::string enc_mode);
py::bytes PyDecrypt(std::string encrypt_data_path, char *key, const size_t key_len, std::string dec_mode);
bool PyIsCipherFile(const std::string &file_path);
} // namespace pipeline
} // namespace mindspore

View File

@ -48,3 +48,7 @@ endif()
if(USE_GLOG)
target_link_libraries(mindspore_core PRIVATE mindspore::glog)
endif()
if(${CMAKE_SYSTEM_NAME} MATCHES "Linux")
target_link_libraries(mindspore_core PRIVATE mindspore::crypto -pthread)
endif()

View File

@ -26,6 +26,7 @@
#include "load_mindir/load_model.h"
#include "load_mindir/anf_model_parser.h"
#include "proto/mind_ir.pb.h"
#include "utils/crypto.h"
using std::string;
using std::vector;
@ -121,7 +122,55 @@ bool get_all_files(const std::string &dir_in, std::vector<std::string> *files) {
int endsWith(string s, string sub) { return s.rfind(sub) == (s.length() - sub.length()) ? 1 : 0; }
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite) {
bool ParseModelProto(mind_ir::ModelProto *model, std::string path, const unsigned char *dec_key, const size_t key_len,
const std::string &dec_mode) {
if (dec_key != nullptr) {
size_t plain_len;
auto plain_data = Decrypt(&plain_len, path, dec_key, key_len, dec_mode);
if (plain_data == nullptr) {
MS_LOG(ERROR) << "Decrypt MindIR file failed, please check the correctness of the dec_key or dec_mode.";
return false;
}
if (!model->ParseFromArray(reinterpret_cast<char *>(plain_data.get()), plain_len)) {
MS_LOG(ERROR) << "Load MindIR file failed, please check the correctness of the file, dec_key or dec_mode.";
return false;
}
} else {
std::fstream input_graph(path, std::ios::in | std::ios::binary);
if (!input_graph || !model->ParseFromIstream(&input_graph)) {
MS_LOG(ERROR) << "Load MindIR file failed, please check the correctness of the file.";
return false;
}
}
return true;
}
bool ParseGraphProto(mind_ir::GraphProto *graph, std::string path, const unsigned char *dec_key, const size_t key_len,
const std::string &dec_mode) {
if (dec_key != nullptr) {
size_t plain_len;
auto plain_data = Decrypt(&plain_len, path, dec_key, key_len, dec_mode);
if (plain_data == nullptr) {
MS_LOG(ERROR) << "Decrypt MindIR file failed, please check the correctness of the dec_key or dec_mode.";
return false;
}
if (!graph->ParseFromArray(reinterpret_cast<char *>(plain_data.get()), plain_len)) {
MS_LOG(ERROR) << "Load variable file failed, please check the correctness of the mindir's variable file, "
"dec_key or dec_mode";
return false;
}
} else {
std::fstream input_param(path, std::ios::in | std::ios::binary);
if (!input_param || !graph->ParseFromIstream(&input_param)) {
MS_LOG(ERROR) << "Load variable file failed, please check the correctness of mindir's variable file.";
return false;
}
}
return true;
}
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite, const unsigned char *dec_key,
const size_t key_len, const std::string &dec_mode) {
const char *file_path = reinterpret_cast<const char *>(file_name.c_str());
char abs_path_buff[PATH_MAX];
char abs_path[PATH_MAX];
@ -136,18 +185,10 @@ std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite
}
#endif
// Read graph
std::fstream input_graph(abs_path_buff, std::ios::in | std::ios::binary);
mind_ir::ModelProto origin_model;
if (!input_graph) {
MS_LOG(ERROR) << "Failed to open file: " << file_name;
if (!ParseModelProto(&origin_model, std::string(abs_path_buff), dec_key, key_len, dec_mode)) {
return nullptr;
}
if (!origin_model.ParseFromIstream(&input_graph)) {
MS_LOG(ERROR) << "Load MindIR file failed, please check the correctness of the file.";
return nullptr;
}
// Load parameter into graph
if (endsWith(abs_path_buff, "_graph.mindir") && origin_model.graph().parameter_size() == 0) {
int path_len = strlen(abs_path_buff) - strlen("graph.mindir");
@ -171,10 +212,8 @@ std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite
int file_size = files.size();
mind_ir::GraphProto *mod_graph = origin_model.mutable_graph();
for (auto file_index = 0; file_index < file_size; file_index++) {
std::fstream input_param(files[file_index], std::ios::in | std::ios::binary);
mind_ir::GraphProto param_graph;
if (!input_param || !param_graph.ParseFromIstream(&input_param)) {
MS_LOG(ERROR) << "Load variable file failed, please check the correctness of mindir's variable file.";
if (!ParseGraphProto(&param_graph, files[file_index], dec_key, key_len, dec_mode)) {
return nullptr;
}

View File

@ -23,7 +23,9 @@
#include "ir/func_graph.h"
namespace mindspore {
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite = false);
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite = false,
const unsigned char *dec_key = nullptr, const size_t key_len = 0,
const std::string &dec_mode = std::string("AES-GCM"));
std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file);
std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite = false);
} // namespace mindspore

View File

@ -0,0 +1,374 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <fstream>
#include <regex>
#include <vector>
#include <algorithm>
#include "utils/crypto.h"
#include "utils/log_adapter.h"
#if not defined(_WIN32)
#include <openssl/aes.h>
#include <openssl/evp.h>
#include <openssl/rand.h>
#endif
namespace mindspore {
void intToByte(Byte *byte, const int32_t &n) {
memset(byte, 0, sizeof(Byte) * 4);
byte[0] = (Byte)(0xFF & n);
byte[1] = (Byte)((0xFF00 & n) >> 8);
byte[2] = (Byte)((0xFF0000 & n) >> 16);
byte[3] = (Byte)((0xFF000000 & n) >> 24);
}
int32_t ByteToint(const Byte *byteArray) {
int32_t res = byteArray[0] & 0xFF;
res |= ((byteArray[1] << 8) & 0xFF00);
res |= ((byteArray[2] << 16) & 0xFF0000);
res += ((byteArray[3] << 24) & 0xFF000000);
return res;
}
bool IsCipherFile(const std::string &file_path) {
std::ifstream fid(file_path, std::ios::in | std::ios::binary);
if (!fid) {
MS_LOG(ERROR) << "Failed to open file " << file_path;
return false;
}
std::vector<char> int_buf(4);
fid.read(int_buf.data(), sizeof(int32_t));
fid.close();
auto flag = ByteToint(reinterpret_cast<Byte *>(int_buf.data()));
return flag == MAGIC_NUM;
}
bool IsCipherFile(const Byte *model_data) {
std::vector<char> int_buf(4);
memcpy(int_buf.data(), model_data, 4);
auto flag = ByteToint(reinterpret_cast<Byte *>(int_buf.data()));
return flag == MAGIC_NUM;
}
#if defined(_WIN32)
std::unique_ptr<Byte[]> Encrypt(size_t *encrypt_len, Byte *plain_data, const size_t plain_len, const Byte *key,
const size_t key_len, const std::string &enc_mode) {
MS_EXCEPTION(NotSupportError) << "Unsupported feature in Windows platform.";
}
std::unique_ptr<Byte[]> Decrypt(size_t *decrypt_len, const std::string &encrypt_data_path, const Byte *key,
const size_t key_len, const std::string &dec_mode) {
MS_EXCEPTION(NotSupportError) << "Unsupported feature in Windows platform.";
}
std::unique_ptr<Byte[]> Decrypt(size_t *decrypt_len, const Byte *model_data, const size_t data_size, const Byte *key,
const size_t key_len, const std::string &dec_mode) {
MS_EXCEPTION(NotSupportError) << "Unsupported feature in Windows platform.";
}
#else
bool ParseEncryptData(const Byte *encrypt_data, const int32_t encrypt_len, std::vector<Byte> *iv,
std::vector<Byte> *cipher_data) {
// encrypt_data is organized in order to iv_len, iv, cipher_len, cipher_data
Byte buf[4];
memcpy(buf, encrypt_data, 4);
auto iv_len = ByteToint(buf);
memcpy(buf, encrypt_data + iv_len + 4, 4);
auto cipher_len = ByteToint(buf);
if (iv_len <= 0 || cipher_len <= 0 || iv_len + cipher_len + 8 != encrypt_len) {
MS_LOG(ERROR) << "Failed to parse encrypt data.";
return false;
}
(*iv).resize(iv_len);
memcpy((*iv).data(), encrypt_data + 4, iv_len);
(*cipher_data).resize(cipher_len);
memcpy((*cipher_data).data(), encrypt_data + iv_len + 8, cipher_len);
return true;
}
bool ParseMode(std::string mode, std::string *alg_mode, std::string *work_mode) {
std::smatch results;
std::regex re("([A-Z]{3})-([A-Z]{3})");
if (!std::regex_match(mode.c_str(), re)) {
MS_LOG(ERROR) << "Mode " << mode << " is invalid.";
return false;
}
std::regex_search(mode, results, re);
*alg_mode = results[1];
*work_mode = results[2];
return true;
}
EVP_CIPHER_CTX *GetEVP_CIPHER_CTX(const std::string &work_mode, const Byte *key, const int32_t key_len, const Byte *iv,
int flag) {
int ret = 0;
EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
if (work_mode != "GCM" && work_mode != "CBC") {
MS_LOG(ERROR) << "Work mode " << work_mode << " is invalid.";
return nullptr;
}
const EVP_CIPHER *(*funcPtr)() = nullptr;
if (work_mode == "GCM") {
switch (key_len) {
case 16:
funcPtr = EVP_aes_128_gcm;
break;
case 24:
funcPtr = EVP_aes_192_gcm;
break;
case 32:
funcPtr = EVP_aes_256_gcm;
break;
default:
MS_EXCEPTION(ValueError) << "The key length must be 16, 24 or 32, but got key length is " << key_len << ".";
}
} else if (work_mode == "CBC") {
switch (key_len) {
case 16:
funcPtr = EVP_aes_128_cbc;
break;
case 24:
funcPtr = EVP_aes_192_cbc;
break;
case 32:
funcPtr = EVP_aes_256_cbc;
break;
default:
MS_EXCEPTION(ValueError) << "The key length must be 16, 24 or 32, but got key length is " << key_len << ".";
}
}
if (flag == 0) {
ret = EVP_EncryptInit_ex(ctx, funcPtr(), NULL, key, iv);
} else if (flag == 1) {
ret = EVP_DecryptInit_ex(ctx, funcPtr(), NULL, key, iv);
}
if (ret != 1) {
MS_LOG(ERROR) << "EVP_EncryptInit_ex failed";
return nullptr;
}
if (work_mode == "CBC") EVP_CIPHER_CTX_set_padding(ctx, 1);
return ctx;
}
bool _BlockEncrypt(Byte *encrypt_data, size_t *encrypt_data_len, Byte *plain_data, const size_t plain_len,
const Byte *key, const int32_t key_len, const std::string &enc_mode) {
int32_t cipher_len = 0;
int32_t iv_len = AES_BLOCK_SIZE;
Byte *iv = new Byte[iv_len];
RAND_bytes(iv, sizeof(Byte) * iv_len);
Byte *iv_cpy = new Byte[16];
memcpy(iv_cpy, iv, 16);
int32_t flen = 0;
std::string alg_mode;
std::string work_mode;
if (!ParseMode(enc_mode, &alg_mode, &work_mode)) {
return false;
}
auto ctx = GetEVP_CIPHER_CTX(work_mode, key, key_len, iv, 0);
if (ctx == nullptr) {
MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX.";
return false;
}
std::vector<Byte> cipher_data(plain_len + 16);
int ret = EVP_EncryptUpdate(ctx, cipher_data.data(), &cipher_len, plain_data, plain_len);
if (ret != 1) {
MS_LOG(ERROR) << "EVP_EncryptUpdate failed";
return false;
}
if (work_mode == "CBC") {
EVP_EncryptFinal_ex(ctx, cipher_data.data() + cipher_len, &flen);
cipher_len += flen;
}
EVP_CIPHER_CTX_free(ctx);
size_t cur = 0;
*encrypt_data_len = sizeof(int32_t) * 2 + iv_len + cipher_len;
std::vector<Byte> byte_buf(4);
intToByte(byte_buf.data(), *encrypt_data_len);
memcpy(encrypt_data + cur, byte_buf.data(), 4);
cur += 4;
intToByte(byte_buf.data(), iv_len);
memcpy(encrypt_data + cur, byte_buf.data(), 4);
cur += 4;
memcpy(encrypt_data + cur, iv_cpy, iv_len);
cur += iv_len;
intToByte(byte_buf.data(), cipher_len);
memcpy(encrypt_data + cur, byte_buf.data(), 4);
cur += 4;
memcpy(encrypt_data + cur, cipher_data.data(), cipher_len);
*encrypt_data_len += 4;
return true;
}
bool _BlockDecrypt(Byte *plain_data, int32_t *plain_len, Byte *encrypt_data, const size_t encrypt_len, const Byte *key,
const int32_t key_len, const std::string &dec_mode) {
std::string alg_mode;
std::string work_mode;
if (!ParseMode(dec_mode, &alg_mode, &work_mode)) {
return false;
}
std::vector<Byte> iv;
std::vector<Byte> cipher_data;
if (!ParseEncryptData(encrypt_data, encrypt_len, &iv, &cipher_data)) {
return false;
}
int ret = 0;
int mlen = 0;
auto ctx = GetEVP_CIPHER_CTX(work_mode, key, key_len, iv.data(), 1);
if (ctx == nullptr) {
MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX.";
return false;
}
ret = EVP_DecryptUpdate(ctx, plain_data, plain_len, cipher_data.data(), cipher_data.size());
if (ret != 1) {
MS_LOG(ERROR) << "EVP_DecryptUpdate failed";
return false;
}
if (work_mode == "CBC") {
ret = EVP_DecryptFinal_ex(ctx, plain_data + *plain_len, &mlen);
if (ret != 1) {
MS_LOG(ERROR) << "EVP_DecryptFinal_ex failed";
return false;
}
*plain_len += mlen;
}
EVP_CIPHER_CTX_free(ctx);
return true;
}
std::unique_ptr<Byte[]> Encrypt(size_t *encrypt_len, Byte *plain_data, const size_t plain_len, const Byte *key,
const size_t key_len, const std::string &enc_mode) {
size_t cur_pos = 0;
size_t block_enc_len = 0;
size_t encrypt_buf_len = plain_len + (plain_len / MAX_BLOCK_SIZE + 1) * 100;
size_t block_enc_buf_len = MAX_BLOCK_SIZE + 100;
std::vector<Byte> int_buf(4);
std::vector<Byte> block_buf(MAX_BLOCK_SIZE);
std::vector<Byte> block_enc_buf(block_enc_buf_len);
auto encrypt_data = std::make_unique<Byte[]>(encrypt_buf_len);
*encrypt_len = 0;
while (cur_pos < plain_len) {
size_t cur_block_size = std::min(MAX_BLOCK_SIZE, plain_len - cur_pos);
memcpy(block_buf.data(), plain_data + cur_pos, cur_block_size);
if (!_BlockEncrypt(block_enc_buf.data(), &block_enc_len, block_buf.data(), cur_block_size, key, key_len,
enc_mode)) {
MS_EXCEPTION(ValueError) << "Failed to encrypt data, please check if enc_key or enc_mode is valid.";
}
intToByte(int_buf.data(), MAGIC_NUM);
memcpy(encrypt_data.get() + *encrypt_len, int_buf.data(), sizeof(int32_t));
*encrypt_len += sizeof(int32_t);
memcpy(encrypt_data.get() + *encrypt_len, block_enc_buf.data(), block_enc_len);
*encrypt_len += block_enc_len;
cur_pos += cur_block_size;
}
return encrypt_data;
}
std::unique_ptr<Byte[]> Decrypt(size_t *decrypt_len, const std::string &encrypt_data_path, const Byte *key,
const size_t key_len, const std::string &dec_mode) {
std::ifstream fid(encrypt_data_path, std::ios::in | std::ios::binary);
if (!fid) {
MS_EXCEPTION(ValueError) << "Open file '" << encrypt_data_path << "' failed, please check the correct of the file.";
}
fid.seekg(0, std::ios_base::end);
size_t file_size = fid.tellg();
fid.clear();
fid.seekg(0);
std::vector<char> block_buf(MAX_BLOCK_SIZE * 2);
std::vector<char> int_buf(4);
std::vector<Byte> decrypt_block_buf(MAX_BLOCK_SIZE * 2);
auto decrypt_data = std::make_unique<Byte[]>(file_size);
int32_t decrypt_block_len;
*decrypt_len = 0;
while (static_cast<size_t>(fid.tellg()) < file_size) {
fid.read(int_buf.data(), sizeof(int32_t));
int cipher_flag = ByteToint(reinterpret_cast<Byte *>(int_buf.data()));
if (cipher_flag != MAGIC_NUM) {
MS_EXCEPTION(ValueError) << "File \"" << encrypt_data_path
<< "\" is not an encrypted file and cannot be decrypted";
}
fid.read(int_buf.data(), sizeof(int32_t));
int32_t block_size = ByteToint(reinterpret_cast<Byte *>(int_buf.data()));
fid.read(block_buf.data(), sizeof(char) * block_size);
if (!(_BlockDecrypt(decrypt_block_buf.data(), &decrypt_block_len, reinterpret_cast<Byte *>(block_buf.data()),
block_size, key, key_len, dec_mode))) {
MS_EXCEPTION(ValueError) << "Failed to decrypt data, please check if dec_key or dec_mode is valid";
}
memcpy(decrypt_data.get() + *decrypt_len, decrypt_block_buf.data(), decrypt_block_len);
*decrypt_len += decrypt_block_len;
}
fid.close();
return decrypt_data;
}
std::unique_ptr<Byte[]> Decrypt(size_t *decrypt_len, const Byte *model_data, const size_t data_size, const Byte *key,
const size_t key_len, const std::string &dec_mode) {
std::vector<char> block_buf(MAX_BLOCK_SIZE * 2);
std::vector<char> int_buf(4);
std::vector<Byte> decrypt_block_buf(MAX_BLOCK_SIZE * 2);
auto decrypt_data = std::make_unique<Byte[]>(data_size);
int32_t decrypt_block_len;
size_t cur_pos = 0;
*decrypt_len = 0;
while (cur_pos < data_size) {
memcpy(int_buf.data(), model_data + cur_pos, 4);
cur_pos += 4;
int cipher_flag = ByteToint(reinterpret_cast<Byte *>(int_buf.data()));
if (cipher_flag != MAGIC_NUM) {
MS_EXCEPTION(ValueError) << "model_data is not encrypted and therefore cannot be decrypted.";
}
memcpy(int_buf.data(), model_data + cur_pos, 4);
cur_pos += 4;
int32_t block_size = ByteToint(reinterpret_cast<Byte *>(int_buf.data()));
memcpy(block_buf.data(), model_data + cur_pos, block_size);
cur_pos += block_size;
if (!(_BlockDecrypt(decrypt_block_buf.data(), &decrypt_block_len, reinterpret_cast<Byte *>(block_buf.data()),
block_size, key, key_len, dec_mode))) {
MS_EXCEPTION(ValueError) << "Failed to decrypt data, please check if dec_key or dec_mode is valid";
}
memcpy(decrypt_data.get() + *decrypt_len, decrypt_block_buf.data(), decrypt_block_len);
*decrypt_len += decrypt_block_len;
}
return decrypt_data;
}
#endif
} // namespace mindspore

View File

@ -0,0 +1,37 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_UTILS_CRYPTO_H
#define MINDSPORE_CORE_UTILS_CRYPTO_H
#include <string>
#include <memory>
typedef unsigned char Byte;
namespace mindspore {
const size_t MAX_BLOCK_SIZE = 512 * 1024 * 1024; // Maximum ciphertext segment, units is Byte
const unsigned int MAGIC_NUM = 0x7F3A5ED8; // Magic number
std::unique_ptr<Byte[]> Encrypt(size_t *encrypt_len, Byte *plain_data, const size_t plain_len, const Byte *key,
const size_t key_len, const std::string &enc_mode);
std::unique_ptr<Byte[]> Decrypt(size_t *decrypt_len, const std::string &encrypt_data_path, const Byte *key,
const size_t key_len, const std::string &dec_mode);
std::unique_ptr<Byte[]> Decrypt(size_t *decrypt_len, const Byte *model_data, const size_t data_size, const Byte *key,
const size_t key_len, const std::string &dec_mode);
bool IsCipherFile(const std::string &file_path);
bool IsCipherFile(const Byte *model_data);
} // namespace mindspore
#endif

View File

@ -166,6 +166,7 @@ include(${TOP_DIR}/cmake/utils.cmake)
include(${TOP_DIR}/cmake/dependency_utils.cmake)
include(${TOP_DIR}/cmake/dependency_securec.cmake)
include(${TOP_DIR}/cmake/external_libs/flatbuffers.cmake)
include(${TOP_DIR}/cmake/external_libs/openssl.cmake)
if(MSLITE_GPU_BACKEND STREQUAL opencl)
include(${TOP_DIR}/cmake/external_libs/opencl.cmake)
endif()

View File

@ -46,11 +46,23 @@ Status Serialization::Load(const void *model_data, size_t data_size, ModelType m
return kSuccess;
}
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
const Key &dec_key, const std::vector<char> &dec_mode) {
MS_LOG(ERROR) << "Unsupported Feature.";
return kLiteError;
}
Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph) {
MS_LOG(ERROR) << "Unsupported Feature.";
return kLiteError;
}
Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph, const Key &dec_key,
const std::vector<char> &dec_mode) {
MS_LOG(ERROR) << "Unsupported Feature.";
return kLiteError;
}
Status Serialization::LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters) {
MS_LOG(ERROR) << "Unsupported feature.";
return kMEFailed;

View File

@ -53,6 +53,13 @@ Flags::Flags() {
"whether the model is going to be trained on device. "
"true | false",
"false");
AddFlag(&Flags::dec_key, "decryptKey",
"The key used to decrypt the file, expressed in hexadecimal characters. Only valid when fmkIn is 'MINDIR'",
"");
AddFlag(&Flags::dec_mode, "decryptMode",
"Decryption method for the MindIR file. Only valid when dec_key is set."
"AES-GCM | AES-CBC",
"AES-GCM");
}
int Flags::InitInputOutputDataType() {

View File

@ -99,6 +99,8 @@ class Flags : public virtual mindspore::lite::FlagParser {
bool trainModel = false;
std::vector<std::string> pluginsPath;
bool disableFusion = false;
std::string dec_key = "";
std::string dec_mode = "AES-GCM";
};
bool CheckOfflineParallelConfig(const std::string &file, ParallelSplitConfig *parallel_split_config);

View File

@ -17,6 +17,7 @@
#include "tools/converter/import/mindspore_importer.h"
#include <memory>
#include <vector>
#include <regex>
#include "tools/converter/parser/parser_utils.h"
#include "tools/converter/import/primitive_adjust.h"
#include "tools/converter/import/mindir_adjust.h"
@ -137,9 +138,53 @@ STATUS MindsporeImporter::HardCodeMindir(const CNodePtr &conv_node, const FuncGr
return lite::RET_OK;
}
size_t MindsporeImporter::Hex2ByteArray(std::string hex_str, unsigned char *byte_array, size_t max_len) {
std::regex r("[0-9a-fA-F]+");
if (!std::regex_match(hex_str, r)) {
MS_LOG(ERROR) << "Some characters of dec_key not in [0-9a-fA-F]";
return 0;
}
if (hex_str.size() % 2 == 1) {
hex_str.push_back('0');
}
size_t byte_len = hex_str.size() / 2;
if (byte_len > max_len) {
MS_LOG(ERROR) << "the hexadecimal dec_key length exceeds the maximum limit: 64";
return 0;
}
for (size_t i = 0; i < byte_len; ++i) {
size_t p = i * 2;
if (hex_str[p] >= 'a' && hex_str[p] <= 'f') {
byte_array[i] = hex_str[p] - 'a' + 10;
} else if (hex_str[p] >= 'A' && hex_str[p] <= 'F') {
byte_array[i] = hex_str[p] - 'A' + 10;
} else {
byte_array[i] = hex_str[p] - '0';
}
if (hex_str[p + 1] >= 'a' && hex_str[p + 1] <= 'f') {
byte_array[i] = (byte_array[i] << 4) | (hex_str[p + 1] - 'a' + 10);
} else if (hex_str[p] >= 'A' && hex_str[p] <= 'F') {
byte_array[i] = (byte_array[i] << 4) | (hex_str[p + 1] - 'A' + 10);
} else {
byte_array[i] = (byte_array[i] << 4) | (hex_str[p + 1] - '0');
}
}
return byte_len;
}
FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
quant_type_ = flag.quantType;
auto func_graph = LoadMindIR(flag.modelFile);
FuncGraphPtr func_graph;
if (flag.dec_key.size() != 0) {
unsigned char key[32];
const size_t key_len = Hex2ByteArray(flag.dec_key, key, 32);
if (key_len == 0) {
return nullptr;
}
func_graph = LoadMindIR(flag.modelFile, false, key, key_len, flag.dec_mode);
} else {
func_graph = LoadMindIR(flag.modelFile);
}
if (func_graph == nullptr) {
MS_LOG(ERROR) << "get funcGraph failed for fmk:MINDIR";
return nullptr;

View File

@ -17,6 +17,7 @@
#ifndef MINDSPORE_LITE_TOOLS_IMPORT_MINDSPORE_IMPORTER_H_
#define MINDSPORE_LITE_TOOLS_IMPORT_MINDSPORE_IMPORTER_H_
#include <string>
#include "tools/converter/converter_flags.h"
#include "load_mindir/load_model.h"
@ -32,6 +33,7 @@ class MindsporeImporter {
STATUS WeightFormatTransform(const FuncGraphPtr &graph);
STATUS HardCodeMindir(const CNodePtr &conv_node, const FuncGraphPtr &graph);
QuantType quant_type_ = schema::QuantType_QUANT_NONE;
size_t Hex2ByteArray(std::string hex_str, unsigned char *byte_array, size_t max_len);
};
} // namespace mindspore::lite

View File

@ -304,7 +304,7 @@ def _check_append_dict(append_dict):
return append_dict
def load(file_name):
def load(file_name, **kwargs):
"""
Load MindIR.
@ -314,6 +314,11 @@ def load(file_name):
Args:
file_name (str): MindIR file name.
kwargs (dict): Configuration options dictionary.
- dec_key: Byte type key used for decryption. Tha valid length is 16, 24, or 32.
- dec_mode: Specifies the decryption mode, take effect when dec_key is set. Option: 'AES-GCM' | 'AES-CBC'.
Default: 'AES-GCM'.
Returns:
Object, a compiled graph that can executed by `GraphCell`.
@ -341,8 +346,19 @@ def load(file_name):
raise ValueError("The MindIR should end with mindir, please input the correct file name.")
logger.info("Execute the process of loading mindir.")
graph = load_mindir(file_name)
if 'dec_key' in kwargs.keys():
dec_key = Validator.check_isinstance('dec_key', kwargs['dec_key'], bytes)
dec_mode = 'AES-GCM'
if 'dec_mode' in kwargs.keys():
dec_mode = Validator.check_isinstance('dec_mode', kwargs['dec_mode'], str)
graph = load_mindir(file_name, dec_key=dec_key, key_len=len(dec_key), dec_mode=dec_mode)
else:
graph = load_mindir(file_name)
if graph is None:
if _is_cipher_file(file_name):
raise RuntimeError("Load MindIR failed. The file may be encrypted, please pass in the "
"correct dec_key and dec_mode.")
raise RuntimeError("Load MindIR failed.")
return graph
@ -392,7 +408,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
except BaseException as e:
if _is_cipher_file(ckpt_file_name):
logger.error("Failed to read the checkpoint file `%s`. The file may be encrypted, please pass in the "
"dec_key.", ckpt_file_name)
"correct dec_key.", ckpt_file_name)
else:
logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", \
ckpt_file_name)
@ -686,6 +702,9 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):
Default: 127.5.
- std_dev: The variance of input data after preprocessing, used for quantizing the first layer of network.
Default: 127.5.
- enc_key: Byte type key used for encryption. Tha valid length is 16, 24, or 32.
- enc_mode: Specifies the encryption mode, take effect when enc_key is set. Option: 'AES-GCM' | 'AES-CBC'.
Default: 'AES-GCM'.
Examples:
>>> import numpy as np
@ -702,10 +721,20 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):
Validator.check_file_name_by_regular(file_name)
net = _quant_export(net, *inputs, file_format=file_format, **kwargs)
_export(net, file_name, file_format, *inputs)
if 'enc_key' in kwargs.keys():
if file_format != 'MINDIR':
raise ValueError(f"enc_key can be passed in only when file_format=='MINDIR', but got {file_format}")
enc_key = Validator.check_isinstance('enc_key', kwargs['enc_key'], bytes)
enc_mode = 'AES-GCM'
if 'enc_mode' in kwargs.keys():
enc_mode = Validator.check_isinstance('enc_mode', kwargs['enc_mode'], str)
_export(net, file_name, file_format, *inputs, enc_key=enc_key, enc_mode=enc_mode)
else:
_export(net, file_name, file_format, *inputs)
def _export(net, file_name, file_format, *inputs):
def _export(net, file_name, file_format, *inputs, **kwargs):
"""
It is an internal conversion function. Export the MindSpore prediction model to a file in the specified format.
"""
@ -744,13 +773,13 @@ def _export(net, file_name, file_format, *inputs):
os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR)
f.write(onnx_stream)
elif file_format == 'MINDIR':
_save_mindir(net, file_name, *inputs)
_save_mindir(net, file_name, *inputs, **kwargs)
if is_dump_onnx_in_training:
net.set_train(mode=True)
def _save_mindir(net, file_name, *inputs):
def _save_mindir(net, file_name, *inputs, **kwargs):
"""Save MindIR format file."""
model = mindir_model()
@ -764,7 +793,7 @@ def _save_mindir(net, file_name, *inputs):
model.ParseFromString(mindir_stream)
save_together = _mindir_save_together(net_dict, model)
is_encrypt = lambda: 'enc_key' in kwargs.keys() and 'enc_mode' in kwargs.keys()
if save_together:
for param_proto in model.graph.parameter:
param_name = param_proto.name[param_proto.name.find(":")+1:]
@ -781,7 +810,11 @@ def _save_mindir(net, file_name, *inputs):
os.makedirs(dirname, exist_ok=True)
with open(file_name, 'wb') as f:
os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR)
f.write(model.SerializeToString())
model_string = model.SerializeToString()
if is_encrypt():
model_string = _encrypt(model_string, len(model_string), kwargs['enc_key'], len(kwargs['enc_key']),
kwargs['enc_mode'])
f.write(model_string)
else:
logger.warning("Parameters in the net capacity exceeds 1G, save MindIR model and parameters separately.")
# save parameter
@ -815,7 +848,11 @@ def _save_mindir(net, file_name, *inputs):
data_file_name = data_path + "/" + "data_" + str(index)
with open(data_file_name, "ab") as f:
os.chmod(data_file_name, stat.S_IRUSR | stat.S_IWUSR)
f.write(graphproto.SerializeToString())
graph_string = graphproto.SerializeToString()
if is_encrypt():
graph_string = _encrypt(graph_string, len(graph_string), kwargs['enc_key'],
len(kwargs['enc_key']), kwargs['enc_mode'])
f.write(graph_string)
index += 1
data_size = 0
del graphproto.parameter[:]
@ -824,14 +861,22 @@ def _save_mindir(net, file_name, *inputs):
data_file_name = data_path + "/" + "data_" + str(index)
with open(data_file_name, "ab") as f:
os.chmod(data_file_name, stat.S_IRUSR | stat.S_IWUSR)
f.write(graphproto.SerializeToString())
graph_string = graphproto.SerializeToString()
if is_encrypt():
graph_string = _encrypt(graph_string, len(graph_string), kwargs['enc_key'], len(kwargs['enc_key']),
kwargs['enc_mode'])
f.write(graph_string)
# save graph
del model.graph.parameter[:]
graph_file_name = dirname + "/" + file_prefix + "_graph.mindir"
with open(graph_file_name, 'wb') as f:
os.chmod(graph_file_name, stat.S_IRUSR | stat.S_IWUSR)
f.write(model.SerializeToString())
model_string = model.SerializeToString()
if is_encrypt():
model_string = _encrypt(model_string, len(model_string), kwargs['enc_key'], len(kwargs['enc_key']),
kwargs['enc_mode'])
f.write(model_string)
def _mindir_save_together(net_dict, model):

View File

@ -16,6 +16,7 @@
import os
import platform
import stat
import secrets
from unittest import mock
import numpy as np
@ -254,7 +255,7 @@ def test_checkpoint_save_ckpt_with_encryption():
save_checkpoint_seconds=0,
keep_checkpoint_max=5,
keep_checkpoint_per_n_minutes=0,
enc_key=os.urandom(16),
enc_key=secrets.token_bytes(16),
enc_mode="AES-GCM")
ckpt_cb = ModelCheckpoint(config=train_config)
cb_params = _InternalCallbackParam()

View File

@ -17,6 +17,7 @@ import os
import platform
import stat
import time
import secrets
import numpy as np
import pytest
@ -32,7 +33,7 @@ from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.train.callback import _CheckpointManager
from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, \
export, _save_graph
export, _save_graph, load
from ..ut_filter import non_graph_engine
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
@ -332,7 +333,7 @@ def test_save_and_load_checkpoint_for_network_with_encryption():
loss_net = WithLossCell(net, loss)
train_network = TrainOneStepCell(loss_net, opt)
key = os.urandom(16)
key = secrets.token_bytes(16)
mode = "AES-GCM"
ckpt_path = "./encrypt_ckpt.ckpt"
if platform.system().lower() == "windows":
@ -383,6 +384,16 @@ def test_mindir_export():
export(net, input_data, file_name="./me_binary_export", file_format="MINDIR")
@non_graph_engine
def test_mindir_export_and_load_with_encryption():
net = MYNET()
input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
key = secrets.token_bytes(16)
export(net, input_data, file_name="./me_cipher_binary_export.mindir", file_format="MINDIR", enc_key=key)
load("./me_cipher_binary_export.mindir", dec_key=key)
class PrintNet(nn.Cell):
def __init__(self):
super(PrintNet, self).__init__()