Add encryption support for mindir
This commit is contained in:
parent
941835dcf7
commit
5b9b46224b
|
@ -27,21 +27,47 @@
|
|||
#include "include/api/dual_abi_helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
using Key = struct Key {
|
||||
const size_t max_key_len = 32;
|
||||
size_t len;
|
||||
unsigned char key[32];
|
||||
Key(): len(0) {}
|
||||
};
|
||||
|
||||
|
||||
class MS_API Serialization {
|
||||
public:
|
||||
static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph);
|
||||
inline static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
|
||||
const Key &dec_key, const std::string &dec_mode);
|
||||
inline static Status Load(const std::string &file, ModelType model_type, Graph *graph);
|
||||
inline static Status Load(const std::string &file, ModelType model_type, Graph *graph, const Key &dec_key,
|
||||
const std::string &dec_mode);
|
||||
static Status LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters);
|
||||
static Status SetParameters(const std::map<std::string, Buffer> ¶meters, Model *model);
|
||||
static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);
|
||||
static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file);
|
||||
|
||||
private:
|
||||
static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
|
||||
const Key &dec_key, const std::vector<char> &dec_mode);
|
||||
static Status Load(const std::vector<char> &file, ModelType model_type, Graph *graph);
|
||||
static Status Load(const std::vector<char> &file, ModelType model_type, Graph *graph, const Key &dec_key,
|
||||
const std::vector<char> &dec_mode);
|
||||
};
|
||||
|
||||
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
|
||||
const Key &dec_key, const std::string &dec_mode) {
|
||||
return Load(model_data, data_size, model_type, graph, dec_key, StringToChar(dec_mode));
|
||||
}
|
||||
|
||||
Status Serialization::Load(const std::string &file, ModelType model_type, Graph *graph) {
|
||||
return Load(StringToChar(file), model_type, graph);
|
||||
}
|
||||
|
||||
Status Serialization::Load(const std::string &file, ModelType model_type, Graph *graph, const Key &dec_key,
|
||||
const std::string &dec_mode) {
|
||||
return Load(StringToChar(file), model_type, graph, dec_key, StringToChar(dec_mode));
|
||||
}
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_SERIALIZATION_H
|
||||
|
|
|
@ -218,7 +218,6 @@ set(SUB_COMP
|
|||
pipeline/jit
|
||||
pipeline/pynative
|
||||
common debug pybind_api utils vm profiler ps
|
||||
crypto
|
||||
)
|
||||
|
||||
foreach(_comp ${SUB_COMP})
|
||||
|
|
|
@ -1,6 +0,0 @@
|
|||
file(GLOB_RECURSE _CRYPTO_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
add_library(_mindspore_crypto_obj OBJECT ${_CRYPTO_SRC_FILES})
|
||||
|
||||
if(${CMAKE_SYSTEM_NAME} MATCHES "Linux")
|
||||
target_link_libraries(_mindspore_crypto_obj mindspore::crypto)
|
||||
endif()
|
|
@ -1,347 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "crypto/crypto.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace crypto {
|
||||
int64_t Min(int64_t a, int64_t b) { return a < b ? a : b; }
|
||||
|
||||
void *intToByte(Byte *byte, const int32_t &n) {
|
||||
memset(byte, 0, sizeof(Byte) * 4);
|
||||
byte[0] = (Byte)(0xFF & n);
|
||||
byte[1] = (Byte)((0xFF00 & n) >> 8);
|
||||
byte[2] = (Byte)((0xFF0000 & n) >> 16);
|
||||
byte[3] = (Byte)((0xFF000000 & n) >> 24);
|
||||
return byte;
|
||||
}
|
||||
|
||||
int32_t ByteToint(const Byte *byteArray) {
|
||||
int32_t res = byteArray[0] & 0xFF;
|
||||
res |= ((byteArray[1] << 8) & 0xFF00);
|
||||
res |= ((byteArray[2] << 16) & 0xFF0000);
|
||||
res += ((byteArray[3] << 24) & 0xFF000000);
|
||||
return res;
|
||||
}
|
||||
|
||||
bool IsCipherFile(std::string file_path) {
|
||||
std::ifstream fid(file_path, std::ios::in | std::ios::binary);
|
||||
if (!fid) {
|
||||
return false;
|
||||
}
|
||||
char *int_buf = new char[4];
|
||||
fid.read(int_buf, sizeof(int32_t));
|
||||
fid.close();
|
||||
auto flag = ByteToint(reinterpret_cast<Byte *>(int_buf));
|
||||
delete[] int_buf;
|
||||
return flag == MAGIC_NUM;
|
||||
}
|
||||
#if defined(_WIN32)
|
||||
Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, Byte *key, const int32_t key_len,
|
||||
const std::string &enc_mode) {
|
||||
MS_EXCEPTION(NotSupportError) << "Unsupported feature in Windows platform.";
|
||||
}
|
||||
|
||||
Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *key, const int32_t key_len,
|
||||
const std::string &dec_mode) {
|
||||
MS_EXCEPTION(NotSupportError) << "Unsupported feature in Windows platform.";
|
||||
}
|
||||
#else
|
||||
|
||||
bool ParseEncryptData(const Byte *encrypt_data, const int32_t encrypt_len, Byte **iv, int32_t *iv_len,
|
||||
Byte **cipher_data, int32_t *cipher_len) {
|
||||
// encrypt_data is organized in order to iv_len, iv, cipher_len, cipher_data
|
||||
Byte buf[4];
|
||||
memcpy_s(buf, 4, encrypt_data, 4);
|
||||
*iv_len = ByteToint(buf);
|
||||
memcpy_s(buf, 4, encrypt_data + *iv_len + 4, 4);
|
||||
*cipher_len = ByteToint(buf);
|
||||
if (*iv_len <= 0 || *cipher_len <= 0 || *iv_len + *cipher_len + 8 != encrypt_len) {
|
||||
MS_LOG(ERROR) << "Failed to parse encrypt data.";
|
||||
return false;
|
||||
}
|
||||
*iv = new Byte[*iv_len];
|
||||
memcpy_s(*iv, *iv_len, encrypt_data + 4, *iv_len);
|
||||
*cipher_data = new Byte[*cipher_len];
|
||||
memcpy_s(*cipher_data, *cipher_len, encrypt_data + *iv_len + 8, *cipher_len);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ParseMode(std::string mode, std::string *alg_mode, std::string *work_mode) {
|
||||
std::smatch results;
|
||||
std::regex re("([A-Z]{3})-([A-Z]{3})");
|
||||
if (!std::regex_match(mode.c_str(), re)) {
|
||||
MS_LOG(ERROR) << "Mode " << mode << " is invalid.";
|
||||
return false;
|
||||
}
|
||||
std::regex_search(mode, results, re);
|
||||
*alg_mode = results[1];
|
||||
*work_mode = results[2];
|
||||
return true;
|
||||
}
|
||||
|
||||
EVP_CIPHER_CTX *GetEVP_CIPHER_CTX(const std::string &work_mode, const Byte *key, const int32_t key_len, const Byte *iv,
|
||||
int flag) {
|
||||
int ret = 0;
|
||||
EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
|
||||
if (work_mode != "GCM" && work_mode != "CBC") {
|
||||
MS_LOG(ERROR) << "Work mode " << work_mode << " is invalid.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const EVP_CIPHER *(*funcPtr)() = nullptr;
|
||||
if (work_mode == "GCM") {
|
||||
switch (key_len) {
|
||||
case 16:
|
||||
funcPtr = EVP_aes_128_gcm;
|
||||
break;
|
||||
case 24:
|
||||
funcPtr = EVP_aes_192_gcm;
|
||||
break;
|
||||
case 32:
|
||||
funcPtr = EVP_aes_256_gcm;
|
||||
break;
|
||||
default:
|
||||
MS_EXCEPTION(ValueError) << "The key length must be 16, 24 or 32, but got key length is " << key_len << ".";
|
||||
}
|
||||
} else if (work_mode == "CBC") {
|
||||
switch (key_len) {
|
||||
case 16:
|
||||
funcPtr = EVP_aes_128_cbc;
|
||||
break;
|
||||
case 24:
|
||||
funcPtr = EVP_aes_192_cbc;
|
||||
break;
|
||||
case 32:
|
||||
funcPtr = EVP_aes_256_cbc;
|
||||
break;
|
||||
default:
|
||||
MS_EXCEPTION(ValueError) << "The key length must be 16, 24 or 32, but got key length is " << key_len << ".";
|
||||
}
|
||||
}
|
||||
|
||||
if (flag == 0) {
|
||||
ret = EVP_EncryptInit_ex(ctx, funcPtr(), NULL, key, iv);
|
||||
} else if (flag == 1) {
|
||||
ret = EVP_DecryptInit_ex(ctx, funcPtr(), NULL, key, iv);
|
||||
}
|
||||
|
||||
if (ret != 1) {
|
||||
MS_LOG(ERROR) << "EVP_EncryptInit_ex failed";
|
||||
return nullptr;
|
||||
}
|
||||
if (work_mode == "CBC") EVP_CIPHER_CTX_set_padding(ctx, 1);
|
||||
return ctx;
|
||||
}
|
||||
|
||||
bool _BlockEncrypt(Byte *encrypt_data, const int64_t encrypt_data_buf_len, int64_t *encrypt_data_len, Byte *plain_data,
|
||||
const int64_t plain_len, Byte *key, const int32_t key_len, const std::string &enc_mode) {
|
||||
int32_t cipher_len = 0;
|
||||
|
||||
int32_t iv_len = AES_BLOCK_SIZE;
|
||||
Byte *iv = new Byte[iv_len];
|
||||
RAND_bytes(iv, sizeof(Byte) * iv_len);
|
||||
|
||||
Byte *iv_cpy = new Byte[16];
|
||||
memcpy_s(iv_cpy, 16, iv, 16);
|
||||
|
||||
int32_t ret = 0;
|
||||
int32_t flen = 0;
|
||||
std::string alg_mode;
|
||||
std::string work_mode;
|
||||
if (!ParseMode(enc_mode, &alg_mode, &work_mode)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto ctx = GetEVP_CIPHER_CTX(work_mode, key, key_len, iv, 0);
|
||||
if (ctx == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX.";
|
||||
return false;
|
||||
}
|
||||
|
||||
Byte *cipher_data;
|
||||
cipher_data = new Byte[plain_len + 16];
|
||||
ret = EVP_EncryptUpdate(ctx, cipher_data, &cipher_len, plain_data, plain_len);
|
||||
if (ret != 1) {
|
||||
MS_LOG(ERROR) << "EVP_EncryptUpdate failed";
|
||||
delete[] cipher_data;
|
||||
return false;
|
||||
}
|
||||
if (work_mode == "CBC") {
|
||||
EVP_EncryptFinal_ex(ctx, cipher_data + cipher_len, &flen);
|
||||
cipher_len += flen;
|
||||
}
|
||||
EVP_CIPHER_CTX_free(ctx);
|
||||
|
||||
int64_t cur = 0;
|
||||
*encrypt_data_len = sizeof(int32_t) * 2 + iv_len + cipher_len;
|
||||
|
||||
Byte *byte_buf = new Byte[4];
|
||||
intToByte(byte_buf, *encrypt_data_len);
|
||||
memcpy_s(encrypt_data + cur, encrypt_data_buf_len - cur, byte_buf, 4);
|
||||
cur += 4;
|
||||
intToByte(byte_buf, iv_len);
|
||||
memcpy_s(encrypt_data + cur, encrypt_data_buf_len - cur, byte_buf, 4);
|
||||
cur += 4;
|
||||
memcpy_s(encrypt_data + cur, encrypt_data_buf_len - cur, iv_cpy, iv_len);
|
||||
cur += iv_len;
|
||||
intToByte(byte_buf, cipher_len);
|
||||
memcpy_s(encrypt_data + cur, encrypt_data_buf_len - cur, byte_buf, 4);
|
||||
cur += 4;
|
||||
memcpy_s(encrypt_data + cur, encrypt_data_buf_len - cur, cipher_data, cipher_len);
|
||||
*encrypt_data_len += 4;
|
||||
|
||||
delete[] byte_buf;
|
||||
delete[] cipher_data;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool _BlockDecrypt(Byte *plain_data, int32_t *plain_len, Byte *encrypt_data, const int64_t encrypt_len, Byte *key,
|
||||
const int32_t key_len, const std::string &dec_mode) {
|
||||
std::string alg_mode;
|
||||
std::string work_mode;
|
||||
|
||||
if (!ParseMode(dec_mode, &alg_mode, &work_mode)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int32_t iv_len = 0;
|
||||
int32_t cipher_len = 0;
|
||||
Byte *iv = NULL;
|
||||
Byte *cipher_data = NULL;
|
||||
|
||||
if (!ParseEncryptData(encrypt_data, encrypt_len, &iv, &iv_len, &cipher_data, &cipher_len)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int ret = 0;
|
||||
int mlen = 0;
|
||||
|
||||
auto ctx = GetEVP_CIPHER_CTX(work_mode, key, key_len, iv, 1);
|
||||
if (ctx == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX.";
|
||||
return false;
|
||||
}
|
||||
ret = EVP_DecryptUpdate(ctx, plain_data, plain_len, cipher_data, cipher_len);
|
||||
if (ret != 1) {
|
||||
MS_LOG(ERROR) << "EVP_DecryptUpdate failed";
|
||||
return false;
|
||||
}
|
||||
if (work_mode == "CBC") {
|
||||
ret = EVP_DecryptFinal_ex(ctx, plain_data + *plain_len, &mlen);
|
||||
if (ret != 1) {
|
||||
MS_LOG(ERROR) << "EVP_DecryptFinal_ex failed";
|
||||
return false;
|
||||
}
|
||||
*plain_len += mlen;
|
||||
}
|
||||
delete[] iv;
|
||||
delete[] cipher_data;
|
||||
EVP_CIPHER_CTX_free(ctx);
|
||||
return true;
|
||||
}
|
||||
|
||||
Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, Byte *key, const int32_t key_len,
|
||||
const std::string &enc_mode) {
|
||||
int64_t cur_pos = 0;
|
||||
int64_t block_enc_len = 0;
|
||||
int64_t encrypt_buf_len = plain_len + (plain_len / MAX_BLOCK_SIZE + 1) * 100;
|
||||
int64_t block_enc_buf_len = MAX_BLOCK_SIZE + 100;
|
||||
Byte *byte_buf = new Byte[4];
|
||||
Byte *encrypt_data = new Byte[encrypt_buf_len];
|
||||
Byte *block_buf = new Byte[MAX_BLOCK_SIZE];
|
||||
Byte *block_enc_buf = new Byte[block_enc_buf_len];
|
||||
|
||||
*encrypt_len = 0;
|
||||
while (cur_pos < plain_len) {
|
||||
int64_t cur_block_size = Min(MAX_BLOCK_SIZE, plain_len - cur_pos);
|
||||
memcpy_s(block_buf, MAX_BLOCK_SIZE, plain_data + cur_pos, cur_block_size);
|
||||
|
||||
if (!_BlockEncrypt(block_enc_buf, block_enc_buf_len, &block_enc_len, block_buf, cur_block_size, key, key_len,
|
||||
enc_mode)) {
|
||||
delete[] byte_buf;
|
||||
delete[] block_buf;
|
||||
delete[] block_enc_buf;
|
||||
delete[] encrypt_data;
|
||||
MS_EXCEPTION(ValueError) << "Failed to encrypt data, please check if enc_key or enc_mode is valid.";
|
||||
}
|
||||
intToByte(byte_buf, MAGIC_NUM);
|
||||
memcpy_s(encrypt_data + *encrypt_len, encrypt_buf_len - *encrypt_len, byte_buf, sizeof(int32_t));
|
||||
*encrypt_len += sizeof(int32_t);
|
||||
memcpy_s(encrypt_data + *encrypt_len, encrypt_buf_len - *encrypt_len, block_enc_buf, block_enc_len);
|
||||
*encrypt_len += block_enc_len;
|
||||
cur_pos += cur_block_size;
|
||||
}
|
||||
delete[] byte_buf;
|
||||
delete[] block_buf;
|
||||
delete[] block_enc_buf;
|
||||
return encrypt_data;
|
||||
}
|
||||
|
||||
Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *key, const int32_t key_len,
|
||||
const std::string &dec_mode) {
|
||||
char *block_buf = new char[MAX_BLOCK_SIZE * 2];
|
||||
char *int_buf = new char[4];
|
||||
Byte *decrypt_block_buf = new Byte[MAX_BLOCK_SIZE * 2];
|
||||
int32_t decrypt_block_len;
|
||||
|
||||
std::ifstream fid(encrypt_data_path, std::ios::in | std::ios::binary);
|
||||
if (!fid) {
|
||||
MS_EXCEPTION(ValueError) << "Open file '" << encrypt_data_path << "' failed, please check the correct of the file.";
|
||||
}
|
||||
fid.seekg(0, std::ios_base::end);
|
||||
int64_t file_size = fid.tellg();
|
||||
fid.clear();
|
||||
fid.seekg(0);
|
||||
Byte *decrypt_data = new Byte[file_size];
|
||||
|
||||
*decrypt_len = 0;
|
||||
while (fid.tellg() < file_size) {
|
||||
fid.read(int_buf, sizeof(int32_t));
|
||||
int cipher_flag = ByteToint(reinterpret_cast<Byte *>(int_buf));
|
||||
if (cipher_flag != MAGIC_NUM) {
|
||||
delete[] block_buf;
|
||||
delete[] int_buf;
|
||||
delete[] decrypt_block_buf;
|
||||
delete[] decrypt_data;
|
||||
MS_EXCEPTION(ValueError) << "File \"" << encrypt_data_path
|
||||
<< "\"is not an encrypted file and cannot be decrypted";
|
||||
}
|
||||
fid.read(int_buf, sizeof(int32_t));
|
||||
|
||||
int32_t block_size = ByteToint(reinterpret_cast<Byte *>(int_buf));
|
||||
fid.read(block_buf, sizeof(char) * block_size);
|
||||
if (!(_BlockDecrypt(decrypt_block_buf, &decrypt_block_len, reinterpret_cast<Byte *>(block_buf), block_size, key,
|
||||
key_len, dec_mode))) {
|
||||
delete[] block_buf;
|
||||
delete[] int_buf;
|
||||
delete[] decrypt_block_buf;
|
||||
delete[] decrypt_data;
|
||||
MS_EXCEPTION(ValueError) << "Failed to decrypt data, please check if dec_key or dec_mode is valid";
|
||||
}
|
||||
auto destMax = Min(file_size - *decrypt_len, SECUREC_MEM_MAX_LEN);
|
||||
memcpy_s(decrypt_data + *decrypt_len, destMax, decrypt_block_buf, decrypt_block_len);
|
||||
*decrypt_len += decrypt_block_len;
|
||||
}
|
||||
fid.close();
|
||||
delete[] block_buf;
|
||||
delete[] int_buf;
|
||||
delete[] decrypt_block_buf;
|
||||
return decrypt_data;
|
||||
}
|
||||
#endif
|
||||
} // namespace crypto
|
||||
} // namespace mindspore
|
|
@ -1,46 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_CRYPTO_CRYPTO_H
|
||||
#define MINDSPORE_CCSRC_CRYPTO_CRYPTO_H
|
||||
|
||||
#if not defined(_WIN32)
|
||||
#include <openssl/aes.h>
|
||||
#include <openssl/evp.h>
|
||||
#include <openssl/rand.h>
|
||||
#endif
|
||||
|
||||
#include <stdio.h>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
#include <regex>
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
typedef unsigned char Byte;
|
||||
|
||||
namespace mindspore {
|
||||
namespace crypto {
|
||||
const int MAX_BLOCK_SIZE = 512 * 1024 * 1024; // Maximum ciphertext segment, units is Byte
|
||||
const unsigned int MAGIC_NUM = 0x7F3A5ED8; // Magic number
|
||||
|
||||
Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, Byte *key, const int32_t key_len,
|
||||
const std::string &enc_mode);
|
||||
Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *key, const int32_t key_len,
|
||||
const std::string &dec_mode);
|
||||
bool IsCipherFile(const std::string file_path);
|
||||
} // namespace crypto
|
||||
} // namespace mindspore
|
||||
#endif
|
|
@ -1,41 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "crypto/crypto_pybind.h"
|
||||
namespace mindspore {
|
||||
namespace crypto {
|
||||
py::bytes PyEncrypt(char *plain_data, const int64_t plain_len, char *key, const int32_t key_len, std::string enc_mode) {
|
||||
int64_t encrypt_len;
|
||||
char *encrypt_data;
|
||||
encrypt_data = reinterpret_cast<char *>(Encrypt(&encrypt_len, reinterpret_cast<Byte *>(plain_data), plain_len,
|
||||
reinterpret_cast<Byte *>(key), key_len, enc_mode));
|
||||
auto py_encrypt_data = py::bytes(encrypt_data, encrypt_len);
|
||||
delete[] encrypt_data;
|
||||
return py_encrypt_data;
|
||||
}
|
||||
|
||||
py::bytes PyDecrypt(std::string encrypt_data_path, char *key, const int32_t key_len, std::string dec_mode) {
|
||||
int64_t decrypt_len;
|
||||
char *decrypt_data;
|
||||
decrypt_data = reinterpret_cast<char *>(
|
||||
Decrypt(&decrypt_len, encrypt_data_path, reinterpret_cast<Byte *>(key), key_len, dec_mode));
|
||||
auto py_decrypt_data = py::bytes(decrypt_data, decrypt_len);
|
||||
delete[] decrypt_data;
|
||||
return py_decrypt_data;
|
||||
}
|
||||
bool PyIsCipherFile(std::string file_path) { return IsCipherFile(file_path); }
|
||||
} // namespace crypto
|
||||
} // namespace mindspore
|
|
@ -1,32 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_CRYPTO_CRYPTO_PYBIND_H
|
||||
#define MINDSPORE_CCSRC_CRYPTO_CRYPTO_PYBIND_H
|
||||
#include "crypto/crypto.h"
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <string>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace mindspore {
|
||||
namespace crypto {
|
||||
py::bytes PyEncrypt(char *plain_data, const int64_t plain_len, char *key, const int32_t key_len, std::string enc_mode);
|
||||
py::bytes PyDecrypt(std::string encrypt_data_path, char *key, const int32_t key_len, std::string dec_mode);
|
||||
bool PyIsCipherFile(std::string file_path);
|
||||
} // namespace crypto
|
||||
} // namespace mindspore
|
||||
#endif
|
|
@ -18,6 +18,7 @@
|
|||
#include "cxx_api/graph/graph_data.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "mindspore/core/load_mindir/load_model.h"
|
||||
#include "utils/crypto.h"
|
||||
|
||||
namespace mindspore {
|
||||
static Buffer ReadFile(const std::string &file) {
|
||||
|
@ -78,7 +79,47 @@ Status Serialization::Load(const void *model_data, size_t data_size, ModelType m
|
|||
try {
|
||||
anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(model_data), data_size);
|
||||
} catch (const std::exception &) {
|
||||
MS_LOG(ERROR) << "Load model failed.";
|
||||
if (IsCipherFile(reinterpret_cast<const unsigned char *>(model_data))) {
|
||||
MS_LOG(ERROR) << "Load model failed. The model_data may be encrypted, please pass in correct key.";
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Load model failed.";
|
||||
}
|
||||
return kMEInvalidInput;
|
||||
}
|
||||
|
||||
*graph = Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
|
||||
return kSuccess;
|
||||
} else if (model_type == kOM) {
|
||||
*graph = Graph(std::make_shared<Graph::GraphData>(Buffer(model_data, data_size), kOM));
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
MS_LOG(ERROR) << "Unsupported ModelType " << model_type;
|
||||
return kMEInvalidInput;
|
||||
}
|
||||
|
||||
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
|
||||
const Key &dec_key, const std::vector<char> &dec_mode) {
|
||||
if (graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Output args graph is nullptr.";
|
||||
return kMEInvalidInput;
|
||||
}
|
||||
|
||||
if (model_type == kMindIR) {
|
||||
FuncGraphPtr anf_graph = nullptr;
|
||||
try {
|
||||
if (dec_key.len > dec_key.max_key_len) {
|
||||
MS_LOG(ERROR) << "The key length exceeds maximum length: 32.";
|
||||
return kMEInvalidInput;
|
||||
} else {
|
||||
size_t plain_data_size;
|
||||
std::string dec_mode_str(dec_mode.begin(), dec_mode.end());
|
||||
auto plain_data = mindspore::Decrypt(&plain_data_size, reinterpret_cast<const unsigned char *>(model_data),
|
||||
data_size, dec_key.key, dec_key.len, dec_mode_str);
|
||||
anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(plain_data.get()), plain_data_size);
|
||||
}
|
||||
} catch (const std::exception &) {
|
||||
MS_LOG(ERROR) << "Load model failed. Please check the valid of dec_key and dec_mode.";
|
||||
return kMEInvalidInput;
|
||||
}
|
||||
|
||||
|
@ -103,7 +144,48 @@ Status Serialization::Load(const std::vector<char> &file, ModelType model_type,
|
|||
if (model_type == kMindIR) {
|
||||
FuncGraphPtr anf_graph = LoadMindIR(file_path);
|
||||
if (anf_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Load model failed.";
|
||||
if (IsCipherFile(file_path)) {
|
||||
MS_LOG(ERROR) << "Load model failed. The file may be encrypted, please pass in correct key.";
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Load model failed.";
|
||||
}
|
||||
return kMEInvalidInput;
|
||||
}
|
||||
*graph = Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
|
||||
return kSuccess;
|
||||
} else if (model_type == kOM) {
|
||||
Buffer data = ReadFile(file_path);
|
||||
if (data.Data() == nullptr) {
|
||||
MS_LOG(ERROR) << "Read file " << file_path << " failed.";
|
||||
return kMEInvalidInput;
|
||||
}
|
||||
*graph = Graph(std::make_shared<Graph::GraphData>(data, kOM));
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
MS_LOG(ERROR) << "Unsupported ModelType " << model_type;
|
||||
return kMEInvalidInput;
|
||||
}
|
||||
|
||||
Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph, const Key &dec_key,
|
||||
const std::vector<char> &dec_mode) {
|
||||
if (graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Output args graph is nullptr.";
|
||||
return kMEInvalidInput;
|
||||
}
|
||||
|
||||
std::string file_path = CharToString(file);
|
||||
if (model_type == kMindIR) {
|
||||
FuncGraphPtr anf_graph;
|
||||
if (dec_key.len > dec_key.max_key_len) {
|
||||
MS_LOG(ERROR) << "The key length exceeds maximum length: 32.";
|
||||
return kMEInvalidInput;
|
||||
} else {
|
||||
std::string dec_mode_str(dec_mode.begin(), dec_mode.end());
|
||||
anf_graph = LoadMindIR(file_path, false, dec_key.key, dec_key.len, dec_mode_str);
|
||||
}
|
||||
if (anf_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Load model failed. Please check the valid of dec_key and dec_mode";
|
||||
return kMEInvalidInput;
|
||||
}
|
||||
*graph = Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
|
||||
|
|
|
@ -28,7 +28,6 @@
|
|||
#include "utils/mpi/mpi_config.h"
|
||||
#include "frontend/parallel/context.h"
|
||||
#include "frontend/parallel/costmodel_context.h"
|
||||
#include "crypto/crypto_pybind.h"
|
||||
#ifdef ENABLE_GPU_COLLECTIVE
|
||||
#include "runtime/device/gpu/distribution/collective_init.h"
|
||||
#else
|
||||
|
@ -112,7 +111,9 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
(void)m.def("init_pipeline", &mindspore::pipeline::InitPipeline, "Init Pipeline.");
|
||||
|
||||
(void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph.");
|
||||
(py::object) m.def("load_mindir", &mindspore::pipeline::LoadMindIR, py::arg("file_name"), "Load model as Graph.");
|
||||
(py::object)
|
||||
m.def("load_mindir", &mindspore::pipeline::LoadMindIR, py::arg("file_name"), py::arg("dec_key") = nullptr,
|
||||
py::arg("key_len") = py::int_(0), py::arg("dec_mode") = py::str("AES-GCM"), "Load model as Graph.");
|
||||
|
||||
(void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
|
||||
.def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")
|
||||
|
@ -367,7 +368,7 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
.def(py::init())
|
||||
.def("get_all_ops_info", &OpInfoLoaderPy::GetAllOpsInfo, "get all ops info.");
|
||||
|
||||
(void)m.def("_encrypt", &mindspore::crypto::PyEncrypt, "Encrypt the data.");
|
||||
(void)m.def("_decrypt", &mindspore::crypto::PyDecrypt, "Decrypt the data.");
|
||||
(void)m.def("_is_cipher_file", &mindspore::crypto::PyIsCipherFile, "Determine whether the file is encrypted");
|
||||
(void)m.def("_encrypt", &mindspore::pipeline::PyEncrypt, "Encrypt the data.");
|
||||
(void)m.def("_decrypt", &mindspore::pipeline::PyDecrypt, "Decrypt the data.");
|
||||
(void)m.def("_is_cipher_file", &mindspore::pipeline::PyIsCipherFile, "Determine whether the file is encrypted");
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include "pipeline/jit/pipeline.h"
|
||||
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
|
@ -52,6 +53,7 @@
|
|||
#include "pipeline/jit/prim_bprop_optimizer.h"
|
||||
#include "runtime/hardware/device_context_manager.h"
|
||||
#include "runtime/framework/actor/actor_common.h"
|
||||
#include "utils/crypto.h"
|
||||
|
||||
#if ((defined ENABLE_CPU) && (!defined _WIN32))
|
||||
#include "ps/constants.h"
|
||||
|
@ -189,7 +191,7 @@ void GetCachedFuncGraph(const ResourcePtr &resource) {
|
|||
return;
|
||||
}
|
||||
MS_LOG(INFO) << "Use the compilation cache \"" << realpath.value() << "\" and execute the backend actions only.";
|
||||
FuncGraphPtr fg = LoadMindIR(realpath.value());
|
||||
FuncGraphPtr fg = mindspore::LoadMindIR(realpath.value());
|
||||
if (fg == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed to load the compilation cache file: " << realpath.value();
|
||||
}
|
||||
|
@ -1185,7 +1187,10 @@ void ExportGraph(const std::string &file_name, const std::string &, const std::s
|
|||
#endif
|
||||
}
|
||||
|
||||
FuncGraphPtr LoadMindIR(const std::string &file_name) { return mindspore::LoadMindIR(file_name); }
|
||||
FuncGraphPtr LoadMindIR(const std::string &file_name, char *dec_key, const size_t key_len,
|
||||
const std::string &dec_mode) {
|
||||
return mindspore::LoadMindIR(file_name, false, reinterpret_cast<unsigned char *>(dec_key), key_len, dec_mode);
|
||||
}
|
||||
|
||||
void ReleaseGeTsd() {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
|
@ -1282,5 +1287,24 @@ void ClearResAtexit() {
|
|||
parse::CleanDataClassToClassMap();
|
||||
trace::ClearTraceStack();
|
||||
}
|
||||
|
||||
py::bytes PyEncrypt(char *plain_data, const size_t plain_len, char *key, const size_t key_len, std::string enc_mode) {
|
||||
size_t encrypt_len;
|
||||
auto encrypt_data = mindspore::Encrypt(&encrypt_len, reinterpret_cast<Byte *>(plain_data), plain_len,
|
||||
reinterpret_cast<Byte *>(key), key_len, enc_mode);
|
||||
auto py_encrypt_data = py::bytes(reinterpret_cast<char *>(encrypt_data.get()), encrypt_len);
|
||||
return py_encrypt_data;
|
||||
}
|
||||
|
||||
py::bytes PyDecrypt(std::string encrypt_data_path, char *key, const size_t key_len, std::string dec_mode) {
|
||||
size_t decrypt_len;
|
||||
|
||||
auto decrypt_data =
|
||||
mindspore::Decrypt(&decrypt_len, encrypt_data_path, reinterpret_cast<Byte *>(key), key_len, dec_mode);
|
||||
auto py_decrypt_data = py::bytes(reinterpret_cast<char *>(decrypt_data.get()), decrypt_len);
|
||||
return py_decrypt_data;
|
||||
}
|
||||
|
||||
bool PyIsCipherFile(const std::string &file_path) { return mindspore::IsCipherFile(file_path); }
|
||||
} // namespace pipeline
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -143,7 +143,7 @@ void ClearResAtexit();
|
|||
void ReleaseGeTsd();
|
||||
|
||||
void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase);
|
||||
FuncGraphPtr LoadMindIR(const std::string &file_name);
|
||||
FuncGraphPtr LoadMindIR(const std::string &file_name, char *dec_key, const size_t key_len, const std::string &dec_mode);
|
||||
|
||||
// init and exec dataset sub graph
|
||||
bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size,
|
||||
|
@ -157,6 +157,10 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
|
|||
|
||||
void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list);
|
||||
|
||||
py::bytes PyEncrypt(char *plain_data, const size_t plain_len, char *key, const size_t key_len, std::string enc_mode);
|
||||
py::bytes PyDecrypt(std::string encrypt_data_path, char *key, const size_t key_len, std::string dec_mode);
|
||||
bool PyIsCipherFile(const std::string &file_path);
|
||||
|
||||
} // namespace pipeline
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -48,3 +48,7 @@ endif()
|
|||
if(USE_GLOG)
|
||||
target_link_libraries(mindspore_core PRIVATE mindspore::glog)
|
||||
endif()
|
||||
|
||||
if(${CMAKE_SYSTEM_NAME} MATCHES "Linux")
|
||||
target_link_libraries(mindspore_core PRIVATE mindspore::crypto -pthread)
|
||||
endif()
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include "load_mindir/load_model.h"
|
||||
#include "load_mindir/anf_model_parser.h"
|
||||
#include "proto/mind_ir.pb.h"
|
||||
#include "utils/crypto.h"
|
||||
|
||||
using std::string;
|
||||
using std::vector;
|
||||
|
@ -121,7 +122,55 @@ bool get_all_files(const std::string &dir_in, std::vector<std::string> *files) {
|
|||
|
||||
int endsWith(string s, string sub) { return s.rfind(sub) == (s.length() - sub.length()) ? 1 : 0; }
|
||||
|
||||
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite) {
|
||||
bool ParseModelProto(mind_ir::ModelProto *model, std::string path, const unsigned char *dec_key, const size_t key_len,
|
||||
const std::string &dec_mode) {
|
||||
if (dec_key != nullptr) {
|
||||
size_t plain_len;
|
||||
auto plain_data = Decrypt(&plain_len, path, dec_key, key_len, dec_mode);
|
||||
if (plain_data == nullptr) {
|
||||
MS_LOG(ERROR) << "Decrypt MindIR file failed, please check the correctness of the dec_key or dec_mode.";
|
||||
return false;
|
||||
}
|
||||
if (!model->ParseFromArray(reinterpret_cast<char *>(plain_data.get()), plain_len)) {
|
||||
MS_LOG(ERROR) << "Load MindIR file failed, please check the correctness of the file, dec_key or dec_mode.";
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
std::fstream input_graph(path, std::ios::in | std::ios::binary);
|
||||
if (!input_graph || !model->ParseFromIstream(&input_graph)) {
|
||||
MS_LOG(ERROR) << "Load MindIR file failed, please check the correctness of the file.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ParseGraphProto(mind_ir::GraphProto *graph, std::string path, const unsigned char *dec_key, const size_t key_len,
|
||||
const std::string &dec_mode) {
|
||||
if (dec_key != nullptr) {
|
||||
size_t plain_len;
|
||||
auto plain_data = Decrypt(&plain_len, path, dec_key, key_len, dec_mode);
|
||||
if (plain_data == nullptr) {
|
||||
MS_LOG(ERROR) << "Decrypt MindIR file failed, please check the correctness of the dec_key or dec_mode.";
|
||||
return false;
|
||||
}
|
||||
if (!graph->ParseFromArray(reinterpret_cast<char *>(plain_data.get()), plain_len)) {
|
||||
MS_LOG(ERROR) << "Load variable file failed, please check the correctness of the mindir's variable file, "
|
||||
"dec_key or dec_mode";
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
std::fstream input_param(path, std::ios::in | std::ios::binary);
|
||||
if (!input_param || !graph->ParseFromIstream(&input_param)) {
|
||||
MS_LOG(ERROR) << "Load variable file failed, please check the correctness of mindir's variable file.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite, const unsigned char *dec_key,
|
||||
const size_t key_len, const std::string &dec_mode) {
|
||||
const char *file_path = reinterpret_cast<const char *>(file_name.c_str());
|
||||
char abs_path_buff[PATH_MAX];
|
||||
char abs_path[PATH_MAX];
|
||||
|
@ -136,18 +185,10 @@ std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite
|
|||
}
|
||||
#endif
|
||||
// Read graph
|
||||
std::fstream input_graph(abs_path_buff, std::ios::in | std::ios::binary);
|
||||
mind_ir::ModelProto origin_model;
|
||||
|
||||
if (!input_graph) {
|
||||
MS_LOG(ERROR) << "Failed to open file: " << file_name;
|
||||
if (!ParseModelProto(&origin_model, std::string(abs_path_buff), dec_key, key_len, dec_mode)) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!origin_model.ParseFromIstream(&input_graph)) {
|
||||
MS_LOG(ERROR) << "Load MindIR file failed, please check the correctness of the file.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Load parameter into graph
|
||||
if (endsWith(abs_path_buff, "_graph.mindir") && origin_model.graph().parameter_size() == 0) {
|
||||
int path_len = strlen(abs_path_buff) - strlen("graph.mindir");
|
||||
|
@ -171,10 +212,8 @@ std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite
|
|||
int file_size = files.size();
|
||||
mind_ir::GraphProto *mod_graph = origin_model.mutable_graph();
|
||||
for (auto file_index = 0; file_index < file_size; file_index++) {
|
||||
std::fstream input_param(files[file_index], std::ios::in | std::ios::binary);
|
||||
mind_ir::GraphProto param_graph;
|
||||
if (!input_param || !param_graph.ParseFromIstream(&input_param)) {
|
||||
MS_LOG(ERROR) << "Load variable file failed, please check the correctness of mindir's variable file.";
|
||||
if (!ParseGraphProto(¶m_graph, files[file_index], dec_key, key_len, dec_mode)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -23,7 +23,9 @@
|
|||
#include "ir/func_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite = false);
|
||||
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite = false,
|
||||
const unsigned char *dec_key = nullptr, const size_t key_len = 0,
|
||||
const std::string &dec_mode = std::string("AES-GCM"));
|
||||
std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file);
|
||||
std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite = false);
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,374 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <fstream>
|
||||
#include <regex>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "utils/crypto.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
#if not defined(_WIN32)
|
||||
#include <openssl/aes.h>
|
||||
#include <openssl/evp.h>
|
||||
#include <openssl/rand.h>
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
void intToByte(Byte *byte, const int32_t &n) {
|
||||
memset(byte, 0, sizeof(Byte) * 4);
|
||||
byte[0] = (Byte)(0xFF & n);
|
||||
byte[1] = (Byte)((0xFF00 & n) >> 8);
|
||||
byte[2] = (Byte)((0xFF0000 & n) >> 16);
|
||||
byte[3] = (Byte)((0xFF000000 & n) >> 24);
|
||||
}
|
||||
|
||||
int32_t ByteToint(const Byte *byteArray) {
|
||||
int32_t res = byteArray[0] & 0xFF;
|
||||
res |= ((byteArray[1] << 8) & 0xFF00);
|
||||
res |= ((byteArray[2] << 16) & 0xFF0000);
|
||||
res += ((byteArray[3] << 24) & 0xFF000000);
|
||||
return res;
|
||||
}
|
||||
|
||||
bool IsCipherFile(const std::string &file_path) {
|
||||
std::ifstream fid(file_path, std::ios::in | std::ios::binary);
|
||||
if (!fid) {
|
||||
MS_LOG(ERROR) << "Failed to open file " << file_path;
|
||||
return false;
|
||||
}
|
||||
std::vector<char> int_buf(4);
|
||||
fid.read(int_buf.data(), sizeof(int32_t));
|
||||
fid.close();
|
||||
auto flag = ByteToint(reinterpret_cast<Byte *>(int_buf.data()));
|
||||
return flag == MAGIC_NUM;
|
||||
}
|
||||
|
||||
bool IsCipherFile(const Byte *model_data) {
|
||||
std::vector<char> int_buf(4);
|
||||
memcpy(int_buf.data(), model_data, 4);
|
||||
auto flag = ByteToint(reinterpret_cast<Byte *>(int_buf.data()));
|
||||
return flag == MAGIC_NUM;
|
||||
}
|
||||
#if defined(_WIN32)
|
||||
std::unique_ptr<Byte[]> Encrypt(size_t *encrypt_len, Byte *plain_data, const size_t plain_len, const Byte *key,
|
||||
const size_t key_len, const std::string &enc_mode) {
|
||||
MS_EXCEPTION(NotSupportError) << "Unsupported feature in Windows platform.";
|
||||
}
|
||||
|
||||
std::unique_ptr<Byte[]> Decrypt(size_t *decrypt_len, const std::string &encrypt_data_path, const Byte *key,
|
||||
const size_t key_len, const std::string &dec_mode) {
|
||||
MS_EXCEPTION(NotSupportError) << "Unsupported feature in Windows platform.";
|
||||
}
|
||||
|
||||
std::unique_ptr<Byte[]> Decrypt(size_t *decrypt_len, const Byte *model_data, const size_t data_size, const Byte *key,
|
||||
const size_t key_len, const std::string &dec_mode) {
|
||||
MS_EXCEPTION(NotSupportError) << "Unsupported feature in Windows platform.";
|
||||
}
|
||||
#else
|
||||
|
||||
bool ParseEncryptData(const Byte *encrypt_data, const int32_t encrypt_len, std::vector<Byte> *iv,
|
||||
std::vector<Byte> *cipher_data) {
|
||||
// encrypt_data is organized in order to iv_len, iv, cipher_len, cipher_data
|
||||
Byte buf[4];
|
||||
memcpy(buf, encrypt_data, 4);
|
||||
|
||||
auto iv_len = ByteToint(buf);
|
||||
memcpy(buf, encrypt_data + iv_len + 4, 4);
|
||||
|
||||
auto cipher_len = ByteToint(buf);
|
||||
if (iv_len <= 0 || cipher_len <= 0 || iv_len + cipher_len + 8 != encrypt_len) {
|
||||
MS_LOG(ERROR) << "Failed to parse encrypt data.";
|
||||
return false;
|
||||
}
|
||||
(*iv).resize(iv_len);
|
||||
memcpy((*iv).data(), encrypt_data + 4, iv_len);
|
||||
|
||||
(*cipher_data).resize(cipher_len);
|
||||
memcpy((*cipher_data).data(), encrypt_data + iv_len + 8, cipher_len);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ParseMode(std::string mode, std::string *alg_mode, std::string *work_mode) {
|
||||
std::smatch results;
|
||||
std::regex re("([A-Z]{3})-([A-Z]{3})");
|
||||
if (!std::regex_match(mode.c_str(), re)) {
|
||||
MS_LOG(ERROR) << "Mode " << mode << " is invalid.";
|
||||
return false;
|
||||
}
|
||||
std::regex_search(mode, results, re);
|
||||
*alg_mode = results[1];
|
||||
*work_mode = results[2];
|
||||
return true;
|
||||
}
|
||||
|
||||
EVP_CIPHER_CTX *GetEVP_CIPHER_CTX(const std::string &work_mode, const Byte *key, const int32_t key_len, const Byte *iv,
|
||||
int flag) {
|
||||
int ret = 0;
|
||||
EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
|
||||
if (work_mode != "GCM" && work_mode != "CBC") {
|
||||
MS_LOG(ERROR) << "Work mode " << work_mode << " is invalid.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const EVP_CIPHER *(*funcPtr)() = nullptr;
|
||||
if (work_mode == "GCM") {
|
||||
switch (key_len) {
|
||||
case 16:
|
||||
funcPtr = EVP_aes_128_gcm;
|
||||
break;
|
||||
case 24:
|
||||
funcPtr = EVP_aes_192_gcm;
|
||||
break;
|
||||
case 32:
|
||||
funcPtr = EVP_aes_256_gcm;
|
||||
break;
|
||||
default:
|
||||
MS_EXCEPTION(ValueError) << "The key length must be 16, 24 or 32, but got key length is " << key_len << ".";
|
||||
}
|
||||
} else if (work_mode == "CBC") {
|
||||
switch (key_len) {
|
||||
case 16:
|
||||
funcPtr = EVP_aes_128_cbc;
|
||||
break;
|
||||
case 24:
|
||||
funcPtr = EVP_aes_192_cbc;
|
||||
break;
|
||||
case 32:
|
||||
funcPtr = EVP_aes_256_cbc;
|
||||
break;
|
||||
default:
|
||||
MS_EXCEPTION(ValueError) << "The key length must be 16, 24 or 32, but got key length is " << key_len << ".";
|
||||
}
|
||||
}
|
||||
|
||||
if (flag == 0) {
|
||||
ret = EVP_EncryptInit_ex(ctx, funcPtr(), NULL, key, iv);
|
||||
} else if (flag == 1) {
|
||||
ret = EVP_DecryptInit_ex(ctx, funcPtr(), NULL, key, iv);
|
||||
}
|
||||
|
||||
if (ret != 1) {
|
||||
MS_LOG(ERROR) << "EVP_EncryptInit_ex failed";
|
||||
return nullptr;
|
||||
}
|
||||
if (work_mode == "CBC") EVP_CIPHER_CTX_set_padding(ctx, 1);
|
||||
return ctx;
|
||||
}
|
||||
|
||||
bool _BlockEncrypt(Byte *encrypt_data, size_t *encrypt_data_len, Byte *plain_data, const size_t plain_len,
|
||||
const Byte *key, const int32_t key_len, const std::string &enc_mode) {
|
||||
int32_t cipher_len = 0;
|
||||
|
||||
int32_t iv_len = AES_BLOCK_SIZE;
|
||||
Byte *iv = new Byte[iv_len];
|
||||
RAND_bytes(iv, sizeof(Byte) * iv_len);
|
||||
|
||||
Byte *iv_cpy = new Byte[16];
|
||||
memcpy(iv_cpy, iv, 16);
|
||||
|
||||
int32_t flen = 0;
|
||||
std::string alg_mode;
|
||||
std::string work_mode;
|
||||
if (!ParseMode(enc_mode, &alg_mode, &work_mode)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto ctx = GetEVP_CIPHER_CTX(work_mode, key, key_len, iv, 0);
|
||||
if (ctx == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX.";
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<Byte> cipher_data(plain_len + 16);
|
||||
int ret = EVP_EncryptUpdate(ctx, cipher_data.data(), &cipher_len, plain_data, plain_len);
|
||||
if (ret != 1) {
|
||||
MS_LOG(ERROR) << "EVP_EncryptUpdate failed";
|
||||
return false;
|
||||
}
|
||||
if (work_mode == "CBC") {
|
||||
EVP_EncryptFinal_ex(ctx, cipher_data.data() + cipher_len, &flen);
|
||||
cipher_len += flen;
|
||||
}
|
||||
EVP_CIPHER_CTX_free(ctx);
|
||||
|
||||
size_t cur = 0;
|
||||
*encrypt_data_len = sizeof(int32_t) * 2 + iv_len + cipher_len;
|
||||
|
||||
std::vector<Byte> byte_buf(4);
|
||||
intToByte(byte_buf.data(), *encrypt_data_len);
|
||||
memcpy(encrypt_data + cur, byte_buf.data(), 4);
|
||||
cur += 4;
|
||||
|
||||
intToByte(byte_buf.data(), iv_len);
|
||||
memcpy(encrypt_data + cur, byte_buf.data(), 4);
|
||||
cur += 4;
|
||||
|
||||
memcpy(encrypt_data + cur, iv_cpy, iv_len);
|
||||
cur += iv_len;
|
||||
|
||||
intToByte(byte_buf.data(), cipher_len);
|
||||
memcpy(encrypt_data + cur, byte_buf.data(), 4);
|
||||
cur += 4;
|
||||
|
||||
memcpy(encrypt_data + cur, cipher_data.data(), cipher_len);
|
||||
*encrypt_data_len += 4;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool _BlockDecrypt(Byte *plain_data, int32_t *plain_len, Byte *encrypt_data, const size_t encrypt_len, const Byte *key,
|
||||
const int32_t key_len, const std::string &dec_mode) {
|
||||
std::string alg_mode;
|
||||
std::string work_mode;
|
||||
|
||||
if (!ParseMode(dec_mode, &alg_mode, &work_mode)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<Byte> iv;
|
||||
std::vector<Byte> cipher_data;
|
||||
if (!ParseEncryptData(encrypt_data, encrypt_len, &iv, &cipher_data)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int ret = 0;
|
||||
int mlen = 0;
|
||||
auto ctx = GetEVP_CIPHER_CTX(work_mode, key, key_len, iv.data(), 1);
|
||||
if (ctx == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX.";
|
||||
return false;
|
||||
}
|
||||
ret = EVP_DecryptUpdate(ctx, plain_data, plain_len, cipher_data.data(), cipher_data.size());
|
||||
if (ret != 1) {
|
||||
MS_LOG(ERROR) << "EVP_DecryptUpdate failed";
|
||||
return false;
|
||||
}
|
||||
if (work_mode == "CBC") {
|
||||
ret = EVP_DecryptFinal_ex(ctx, plain_data + *plain_len, &mlen);
|
||||
if (ret != 1) {
|
||||
MS_LOG(ERROR) << "EVP_DecryptFinal_ex failed";
|
||||
return false;
|
||||
}
|
||||
*plain_len += mlen;
|
||||
}
|
||||
EVP_CIPHER_CTX_free(ctx);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::unique_ptr<Byte[]> Encrypt(size_t *encrypt_len, Byte *plain_data, const size_t plain_len, const Byte *key,
|
||||
const size_t key_len, const std::string &enc_mode) {
|
||||
size_t cur_pos = 0;
|
||||
size_t block_enc_len = 0;
|
||||
size_t encrypt_buf_len = plain_len + (plain_len / MAX_BLOCK_SIZE + 1) * 100;
|
||||
size_t block_enc_buf_len = MAX_BLOCK_SIZE + 100;
|
||||
std::vector<Byte> int_buf(4);
|
||||
std::vector<Byte> block_buf(MAX_BLOCK_SIZE);
|
||||
std::vector<Byte> block_enc_buf(block_enc_buf_len);
|
||||
auto encrypt_data = std::make_unique<Byte[]>(encrypt_buf_len);
|
||||
|
||||
*encrypt_len = 0;
|
||||
while (cur_pos < plain_len) {
|
||||
size_t cur_block_size = std::min(MAX_BLOCK_SIZE, plain_len - cur_pos);
|
||||
memcpy(block_buf.data(), plain_data + cur_pos, cur_block_size);
|
||||
if (!_BlockEncrypt(block_enc_buf.data(), &block_enc_len, block_buf.data(), cur_block_size, key, key_len,
|
||||
enc_mode)) {
|
||||
MS_EXCEPTION(ValueError) << "Failed to encrypt data, please check if enc_key or enc_mode is valid.";
|
||||
}
|
||||
intToByte(int_buf.data(), MAGIC_NUM);
|
||||
memcpy(encrypt_data.get() + *encrypt_len, int_buf.data(), sizeof(int32_t));
|
||||
*encrypt_len += sizeof(int32_t);
|
||||
memcpy(encrypt_data.get() + *encrypt_len, block_enc_buf.data(), block_enc_len);
|
||||
*encrypt_len += block_enc_len;
|
||||
cur_pos += cur_block_size;
|
||||
}
|
||||
return encrypt_data;
|
||||
}
|
||||
|
||||
std::unique_ptr<Byte[]> Decrypt(size_t *decrypt_len, const std::string &encrypt_data_path, const Byte *key,
|
||||
const size_t key_len, const std::string &dec_mode) {
|
||||
std::ifstream fid(encrypt_data_path, std::ios::in | std::ios::binary);
|
||||
if (!fid) {
|
||||
MS_EXCEPTION(ValueError) << "Open file '" << encrypt_data_path << "' failed, please check the correct of the file.";
|
||||
}
|
||||
fid.seekg(0, std::ios_base::end);
|
||||
size_t file_size = fid.tellg();
|
||||
fid.clear();
|
||||
fid.seekg(0);
|
||||
|
||||
std::vector<char> block_buf(MAX_BLOCK_SIZE * 2);
|
||||
std::vector<char> int_buf(4);
|
||||
std::vector<Byte> decrypt_block_buf(MAX_BLOCK_SIZE * 2);
|
||||
auto decrypt_data = std::make_unique<Byte[]>(file_size);
|
||||
int32_t decrypt_block_len;
|
||||
|
||||
*decrypt_len = 0;
|
||||
while (static_cast<size_t>(fid.tellg()) < file_size) {
|
||||
fid.read(int_buf.data(), sizeof(int32_t));
|
||||
int cipher_flag = ByteToint(reinterpret_cast<Byte *>(int_buf.data()));
|
||||
if (cipher_flag != MAGIC_NUM) {
|
||||
MS_EXCEPTION(ValueError) << "File \"" << encrypt_data_path
|
||||
<< "\" is not an encrypted file and cannot be decrypted";
|
||||
}
|
||||
fid.read(int_buf.data(), sizeof(int32_t));
|
||||
int32_t block_size = ByteToint(reinterpret_cast<Byte *>(int_buf.data()));
|
||||
fid.read(block_buf.data(), sizeof(char) * block_size);
|
||||
if (!(_BlockDecrypt(decrypt_block_buf.data(), &decrypt_block_len, reinterpret_cast<Byte *>(block_buf.data()),
|
||||
block_size, key, key_len, dec_mode))) {
|
||||
MS_EXCEPTION(ValueError) << "Failed to decrypt data, please check if dec_key or dec_mode is valid";
|
||||
}
|
||||
memcpy(decrypt_data.get() + *decrypt_len, decrypt_block_buf.data(), decrypt_block_len);
|
||||
*decrypt_len += decrypt_block_len;
|
||||
}
|
||||
fid.close();
|
||||
return decrypt_data;
|
||||
}
|
||||
|
||||
std::unique_ptr<Byte[]> Decrypt(size_t *decrypt_len, const Byte *model_data, const size_t data_size, const Byte *key,
|
||||
const size_t key_len, const std::string &dec_mode) {
|
||||
std::vector<char> block_buf(MAX_BLOCK_SIZE * 2);
|
||||
std::vector<char> int_buf(4);
|
||||
std::vector<Byte> decrypt_block_buf(MAX_BLOCK_SIZE * 2);
|
||||
auto decrypt_data = std::make_unique<Byte[]>(data_size);
|
||||
int32_t decrypt_block_len;
|
||||
|
||||
size_t cur_pos = 0;
|
||||
*decrypt_len = 0;
|
||||
while (cur_pos < data_size) {
|
||||
memcpy(int_buf.data(), model_data + cur_pos, 4);
|
||||
cur_pos += 4;
|
||||
int cipher_flag = ByteToint(reinterpret_cast<Byte *>(int_buf.data()));
|
||||
if (cipher_flag != MAGIC_NUM) {
|
||||
MS_EXCEPTION(ValueError) << "model_data is not encrypted and therefore cannot be decrypted.";
|
||||
}
|
||||
memcpy(int_buf.data(), model_data + cur_pos, 4);
|
||||
cur_pos += 4;
|
||||
|
||||
int32_t block_size = ByteToint(reinterpret_cast<Byte *>(int_buf.data()));
|
||||
memcpy(block_buf.data(), model_data + cur_pos, block_size);
|
||||
cur_pos += block_size;
|
||||
if (!(_BlockDecrypt(decrypt_block_buf.data(), &decrypt_block_len, reinterpret_cast<Byte *>(block_buf.data()),
|
||||
block_size, key, key_len, dec_mode))) {
|
||||
MS_EXCEPTION(ValueError) << "Failed to decrypt data, please check if dec_key or dec_mode is valid";
|
||||
}
|
||||
memcpy(decrypt_data.get() + *decrypt_len, decrypt_block_buf.data(), decrypt_block_len);
|
||||
*decrypt_len += decrypt_block_len;
|
||||
}
|
||||
return decrypt_data;
|
||||
}
|
||||
#endif
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_UTILS_CRYPTO_H
|
||||
#define MINDSPORE_CORE_UTILS_CRYPTO_H
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
typedef unsigned char Byte;
|
||||
namespace mindspore {
|
||||
const size_t MAX_BLOCK_SIZE = 512 * 1024 * 1024; // Maximum ciphertext segment, units is Byte
|
||||
const unsigned int MAGIC_NUM = 0x7F3A5ED8; // Magic number
|
||||
|
||||
std::unique_ptr<Byte[]> Encrypt(size_t *encrypt_len, Byte *plain_data, const size_t plain_len, const Byte *key,
|
||||
const size_t key_len, const std::string &enc_mode);
|
||||
std::unique_ptr<Byte[]> Decrypt(size_t *decrypt_len, const std::string &encrypt_data_path, const Byte *key,
|
||||
const size_t key_len, const std::string &dec_mode);
|
||||
std::unique_ptr<Byte[]> Decrypt(size_t *decrypt_len, const Byte *model_data, const size_t data_size, const Byte *key,
|
||||
const size_t key_len, const std::string &dec_mode);
|
||||
bool IsCipherFile(const std::string &file_path);
|
||||
bool IsCipherFile(const Byte *model_data);
|
||||
} // namespace mindspore
|
||||
#endif
|
|
@ -166,6 +166,7 @@ include(${TOP_DIR}/cmake/utils.cmake)
|
|||
include(${TOP_DIR}/cmake/dependency_utils.cmake)
|
||||
include(${TOP_DIR}/cmake/dependency_securec.cmake)
|
||||
include(${TOP_DIR}/cmake/external_libs/flatbuffers.cmake)
|
||||
include(${TOP_DIR}/cmake/external_libs/openssl.cmake)
|
||||
if(MSLITE_GPU_BACKEND STREQUAL opencl)
|
||||
include(${TOP_DIR}/cmake/external_libs/opencl.cmake)
|
||||
endif()
|
||||
|
|
|
@ -46,11 +46,23 @@ Status Serialization::Load(const void *model_data, size_t data_size, ModelType m
|
|||
return kSuccess;
|
||||
}
|
||||
|
||||
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
|
||||
const Key &dec_key, const std::vector<char> &dec_mode) {
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
return kLiteError;
|
||||
}
|
||||
|
||||
Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph) {
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
return kLiteError;
|
||||
}
|
||||
|
||||
Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph, const Key &dec_key,
|
||||
const std::vector<char> &dec_mode) {
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
return kLiteError;
|
||||
}
|
||||
|
||||
Status Serialization::LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters) {
|
||||
MS_LOG(ERROR) << "Unsupported feature.";
|
||||
return kMEFailed;
|
||||
|
|
|
@ -53,6 +53,13 @@ Flags::Flags() {
|
|||
"whether the model is going to be trained on device. "
|
||||
"true | false",
|
||||
"false");
|
||||
AddFlag(&Flags::dec_key, "decryptKey",
|
||||
"The key used to decrypt the file, expressed in hexadecimal characters. Only valid when fmkIn is 'MINDIR'",
|
||||
"");
|
||||
AddFlag(&Flags::dec_mode, "decryptMode",
|
||||
"Decryption method for the MindIR file. Only valid when dec_key is set."
|
||||
"AES-GCM | AES-CBC",
|
||||
"AES-GCM");
|
||||
}
|
||||
|
||||
int Flags::InitInputOutputDataType() {
|
||||
|
|
|
@ -99,6 +99,8 @@ class Flags : public virtual mindspore::lite::FlagParser {
|
|||
bool trainModel = false;
|
||||
std::vector<std::string> pluginsPath;
|
||||
bool disableFusion = false;
|
||||
std::string dec_key = "";
|
||||
std::string dec_mode = "AES-GCM";
|
||||
};
|
||||
|
||||
bool CheckOfflineParallelConfig(const std::string &file, ParallelSplitConfig *parallel_split_config);
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "tools/converter/import/mindspore_importer.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <regex>
|
||||
#include "tools/converter/parser/parser_utils.h"
|
||||
#include "tools/converter/import/primitive_adjust.h"
|
||||
#include "tools/converter/import/mindir_adjust.h"
|
||||
|
@ -137,9 +138,53 @@ STATUS MindsporeImporter::HardCodeMindir(const CNodePtr &conv_node, const FuncGr
|
|||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
size_t MindsporeImporter::Hex2ByteArray(std::string hex_str, unsigned char *byte_array, size_t max_len) {
|
||||
std::regex r("[0-9a-fA-F]+");
|
||||
if (!std::regex_match(hex_str, r)) {
|
||||
MS_LOG(ERROR) << "Some characters of dec_key not in [0-9a-fA-F]";
|
||||
return 0;
|
||||
}
|
||||
if (hex_str.size() % 2 == 1) {
|
||||
hex_str.push_back('0');
|
||||
}
|
||||
size_t byte_len = hex_str.size() / 2;
|
||||
if (byte_len > max_len) {
|
||||
MS_LOG(ERROR) << "the hexadecimal dec_key length exceeds the maximum limit: 64";
|
||||
return 0;
|
||||
}
|
||||
for (size_t i = 0; i < byte_len; ++i) {
|
||||
size_t p = i * 2;
|
||||
if (hex_str[p] >= 'a' && hex_str[p] <= 'f') {
|
||||
byte_array[i] = hex_str[p] - 'a' + 10;
|
||||
} else if (hex_str[p] >= 'A' && hex_str[p] <= 'F') {
|
||||
byte_array[i] = hex_str[p] - 'A' + 10;
|
||||
} else {
|
||||
byte_array[i] = hex_str[p] - '0';
|
||||
}
|
||||
if (hex_str[p + 1] >= 'a' && hex_str[p + 1] <= 'f') {
|
||||
byte_array[i] = (byte_array[i] << 4) | (hex_str[p + 1] - 'a' + 10);
|
||||
} else if (hex_str[p] >= 'A' && hex_str[p] <= 'F') {
|
||||
byte_array[i] = (byte_array[i] << 4) | (hex_str[p + 1] - 'A' + 10);
|
||||
} else {
|
||||
byte_array[i] = (byte_array[i] << 4) | (hex_str[p + 1] - '0');
|
||||
}
|
||||
}
|
||||
return byte_len;
|
||||
}
|
||||
|
||||
FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
|
||||
quant_type_ = flag.quantType;
|
||||
auto func_graph = LoadMindIR(flag.modelFile);
|
||||
FuncGraphPtr func_graph;
|
||||
if (flag.dec_key.size() != 0) {
|
||||
unsigned char key[32];
|
||||
const size_t key_len = Hex2ByteArray(flag.dec_key, key, 32);
|
||||
if (key_len == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
func_graph = LoadMindIR(flag.modelFile, false, key, key_len, flag.dec_mode);
|
||||
} else {
|
||||
func_graph = LoadMindIR(flag.modelFile);
|
||||
}
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "get funcGraph failed for fmk:MINDIR";
|
||||
return nullptr;
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#ifndef MINDSPORE_LITE_TOOLS_IMPORT_MINDSPORE_IMPORTER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_IMPORT_MINDSPORE_IMPORTER_H_
|
||||
|
||||
#include <string>
|
||||
#include "tools/converter/converter_flags.h"
|
||||
#include "load_mindir/load_model.h"
|
||||
|
||||
|
@ -32,6 +33,7 @@ class MindsporeImporter {
|
|||
STATUS WeightFormatTransform(const FuncGraphPtr &graph);
|
||||
STATUS HardCodeMindir(const CNodePtr &conv_node, const FuncGraphPtr &graph);
|
||||
QuantType quant_type_ = schema::QuantType_QUANT_NONE;
|
||||
size_t Hex2ByteArray(std::string hex_str, unsigned char *byte_array, size_t max_len);
|
||||
};
|
||||
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -304,7 +304,7 @@ def _check_append_dict(append_dict):
|
|||
return append_dict
|
||||
|
||||
|
||||
def load(file_name):
|
||||
def load(file_name, **kwargs):
|
||||
"""
|
||||
Load MindIR.
|
||||
|
||||
|
@ -314,6 +314,11 @@ def load(file_name):
|
|||
Args:
|
||||
file_name (str): MindIR file name.
|
||||
|
||||
kwargs (dict): Configuration options dictionary.
|
||||
|
||||
- dec_key: Byte type key used for decryption. Tha valid length is 16, 24, or 32.
|
||||
- dec_mode: Specifies the decryption mode, take effect when dec_key is set. Option: 'AES-GCM' | 'AES-CBC'.
|
||||
Default: 'AES-GCM'.
|
||||
Returns:
|
||||
Object, a compiled graph that can executed by `GraphCell`.
|
||||
|
||||
|
@ -341,8 +346,19 @@ def load(file_name):
|
|||
raise ValueError("The MindIR should end with mindir, please input the correct file name.")
|
||||
|
||||
logger.info("Execute the process of loading mindir.")
|
||||
graph = load_mindir(file_name)
|
||||
if 'dec_key' in kwargs.keys():
|
||||
dec_key = Validator.check_isinstance('dec_key', kwargs['dec_key'], bytes)
|
||||
dec_mode = 'AES-GCM'
|
||||
if 'dec_mode' in kwargs.keys():
|
||||
dec_mode = Validator.check_isinstance('dec_mode', kwargs['dec_mode'], str)
|
||||
graph = load_mindir(file_name, dec_key=dec_key, key_len=len(dec_key), dec_mode=dec_mode)
|
||||
else:
|
||||
graph = load_mindir(file_name)
|
||||
|
||||
if graph is None:
|
||||
if _is_cipher_file(file_name):
|
||||
raise RuntimeError("Load MindIR failed. The file may be encrypted, please pass in the "
|
||||
"correct dec_key and dec_mode.")
|
||||
raise RuntimeError("Load MindIR failed.")
|
||||
return graph
|
||||
|
||||
|
@ -392,7 +408,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|||
except BaseException as e:
|
||||
if _is_cipher_file(ckpt_file_name):
|
||||
logger.error("Failed to read the checkpoint file `%s`. The file may be encrypted, please pass in the "
|
||||
"dec_key.", ckpt_file_name)
|
||||
"correct dec_key.", ckpt_file_name)
|
||||
else:
|
||||
logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", \
|
||||
ckpt_file_name)
|
||||
|
@ -686,6 +702,9 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):
|
|||
Default: 127.5.
|
||||
- std_dev: The variance of input data after preprocessing, used for quantizing the first layer of network.
|
||||
Default: 127.5.
|
||||
- enc_key: Byte type key used for encryption. Tha valid length is 16, 24, or 32.
|
||||
- enc_mode: Specifies the encryption mode, take effect when enc_key is set. Option: 'AES-GCM' | 'AES-CBC'.
|
||||
Default: 'AES-GCM'.
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
|
@ -702,10 +721,20 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):
|
|||
|
||||
Validator.check_file_name_by_regular(file_name)
|
||||
net = _quant_export(net, *inputs, file_format=file_format, **kwargs)
|
||||
_export(net, file_name, file_format, *inputs)
|
||||
if 'enc_key' in kwargs.keys():
|
||||
if file_format != 'MINDIR':
|
||||
raise ValueError(f"enc_key can be passed in only when file_format=='MINDIR', but got {file_format}")
|
||||
|
||||
enc_key = Validator.check_isinstance('enc_key', kwargs['enc_key'], bytes)
|
||||
enc_mode = 'AES-GCM'
|
||||
if 'enc_mode' in kwargs.keys():
|
||||
enc_mode = Validator.check_isinstance('enc_mode', kwargs['enc_mode'], str)
|
||||
_export(net, file_name, file_format, *inputs, enc_key=enc_key, enc_mode=enc_mode)
|
||||
else:
|
||||
_export(net, file_name, file_format, *inputs)
|
||||
|
||||
|
||||
def _export(net, file_name, file_format, *inputs):
|
||||
def _export(net, file_name, file_format, *inputs, **kwargs):
|
||||
"""
|
||||
It is an internal conversion function. Export the MindSpore prediction model to a file in the specified format.
|
||||
"""
|
||||
|
@ -744,13 +773,13 @@ def _export(net, file_name, file_format, *inputs):
|
|||
os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR)
|
||||
f.write(onnx_stream)
|
||||
elif file_format == 'MINDIR':
|
||||
_save_mindir(net, file_name, *inputs)
|
||||
_save_mindir(net, file_name, *inputs, **kwargs)
|
||||
|
||||
if is_dump_onnx_in_training:
|
||||
net.set_train(mode=True)
|
||||
|
||||
|
||||
def _save_mindir(net, file_name, *inputs):
|
||||
def _save_mindir(net, file_name, *inputs, **kwargs):
|
||||
"""Save MindIR format file."""
|
||||
model = mindir_model()
|
||||
|
||||
|
@ -764,7 +793,7 @@ def _save_mindir(net, file_name, *inputs):
|
|||
model.ParseFromString(mindir_stream)
|
||||
|
||||
save_together = _mindir_save_together(net_dict, model)
|
||||
|
||||
is_encrypt = lambda: 'enc_key' in kwargs.keys() and 'enc_mode' in kwargs.keys()
|
||||
if save_together:
|
||||
for param_proto in model.graph.parameter:
|
||||
param_name = param_proto.name[param_proto.name.find(":")+1:]
|
||||
|
@ -781,7 +810,11 @@ def _save_mindir(net, file_name, *inputs):
|
|||
os.makedirs(dirname, exist_ok=True)
|
||||
with open(file_name, 'wb') as f:
|
||||
os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR)
|
||||
f.write(model.SerializeToString())
|
||||
model_string = model.SerializeToString()
|
||||
if is_encrypt():
|
||||
model_string = _encrypt(model_string, len(model_string), kwargs['enc_key'], len(kwargs['enc_key']),
|
||||
kwargs['enc_mode'])
|
||||
f.write(model_string)
|
||||
else:
|
||||
logger.warning("Parameters in the net capacity exceeds 1G, save MindIR model and parameters separately.")
|
||||
# save parameter
|
||||
|
@ -815,7 +848,11 @@ def _save_mindir(net, file_name, *inputs):
|
|||
data_file_name = data_path + "/" + "data_" + str(index)
|
||||
with open(data_file_name, "ab") as f:
|
||||
os.chmod(data_file_name, stat.S_IRUSR | stat.S_IWUSR)
|
||||
f.write(graphproto.SerializeToString())
|
||||
graph_string = graphproto.SerializeToString()
|
||||
if is_encrypt():
|
||||
graph_string = _encrypt(graph_string, len(graph_string), kwargs['enc_key'],
|
||||
len(kwargs['enc_key']), kwargs['enc_mode'])
|
||||
f.write(graph_string)
|
||||
index += 1
|
||||
data_size = 0
|
||||
del graphproto.parameter[:]
|
||||
|
@ -824,14 +861,22 @@ def _save_mindir(net, file_name, *inputs):
|
|||
data_file_name = data_path + "/" + "data_" + str(index)
|
||||
with open(data_file_name, "ab") as f:
|
||||
os.chmod(data_file_name, stat.S_IRUSR | stat.S_IWUSR)
|
||||
f.write(graphproto.SerializeToString())
|
||||
graph_string = graphproto.SerializeToString()
|
||||
if is_encrypt():
|
||||
graph_string = _encrypt(graph_string, len(graph_string), kwargs['enc_key'], len(kwargs['enc_key']),
|
||||
kwargs['enc_mode'])
|
||||
f.write(graph_string)
|
||||
|
||||
# save graph
|
||||
del model.graph.parameter[:]
|
||||
graph_file_name = dirname + "/" + file_prefix + "_graph.mindir"
|
||||
with open(graph_file_name, 'wb') as f:
|
||||
os.chmod(graph_file_name, stat.S_IRUSR | stat.S_IWUSR)
|
||||
f.write(model.SerializeToString())
|
||||
model_string = model.SerializeToString()
|
||||
if is_encrypt():
|
||||
model_string = _encrypt(model_string, len(model_string), kwargs['enc_key'], len(kwargs['enc_key']),
|
||||
kwargs['enc_mode'])
|
||||
f.write(model_string)
|
||||
|
||||
|
||||
def _mindir_save_together(net_dict, model):
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
import os
|
||||
import platform
|
||||
import stat
|
||||
import secrets
|
||||
from unittest import mock
|
||||
|
||||
import numpy as np
|
||||
|
@ -254,7 +255,7 @@ def test_checkpoint_save_ckpt_with_encryption():
|
|||
save_checkpoint_seconds=0,
|
||||
keep_checkpoint_max=5,
|
||||
keep_checkpoint_per_n_minutes=0,
|
||||
enc_key=os.urandom(16),
|
||||
enc_key=secrets.token_bytes(16),
|
||||
enc_mode="AES-GCM")
|
||||
ckpt_cb = ModelCheckpoint(config=train_config)
|
||||
cb_params = _InternalCallbackParam()
|
||||
|
|
|
@ -17,6 +17,7 @@ import os
|
|||
import platform
|
||||
import stat
|
||||
import time
|
||||
import secrets
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
@ -32,7 +33,7 @@ from mindspore.nn.optim.momentum import Momentum
|
|||
from mindspore.ops import operations as P
|
||||
from mindspore.train.callback import _CheckpointManager
|
||||
from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, \
|
||||
export, _save_graph
|
||||
export, _save_graph, load
|
||||
from ..ut_filter import non_graph_engine
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
|
||||
|
@ -332,7 +333,7 @@ def test_save_and_load_checkpoint_for_network_with_encryption():
|
|||
|
||||
loss_net = WithLossCell(net, loss)
|
||||
train_network = TrainOneStepCell(loss_net, opt)
|
||||
key = os.urandom(16)
|
||||
key = secrets.token_bytes(16)
|
||||
mode = "AES-GCM"
|
||||
ckpt_path = "./encrypt_ckpt.ckpt"
|
||||
if platform.system().lower() == "windows":
|
||||
|
@ -383,6 +384,16 @@ def test_mindir_export():
|
|||
export(net, input_data, file_name="./me_binary_export", file_format="MINDIR")
|
||||
|
||||
|
||||
@non_graph_engine
|
||||
def test_mindir_export_and_load_with_encryption():
|
||||
net = MYNET()
|
||||
input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
|
||||
key = secrets.token_bytes(16)
|
||||
export(net, input_data, file_name="./me_cipher_binary_export.mindir", file_format="MINDIR", enc_key=key)
|
||||
load("./me_cipher_binary_export.mindir", dec_key=key)
|
||||
|
||||
|
||||
|
||||
class PrintNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(PrintNet, self).__init__()
|
||||
|
|
Loading…
Reference in New Issue