add encryption to lite

This commit is contained in:
yeyunpeng2020 2022-03-03 14:11:40 +08:00
parent c337e64241
commit f670a635f0
32 changed files with 1262 additions and 167 deletions

View File

@ -12,17 +12,66 @@ else()
set(OPENSSL_PATCH_ROOT ${CMAKE_SOURCE_DIR}/third_party/patch/openssl) set(OPENSSL_PATCH_ROOT ${CMAKE_SOURCE_DIR}/third_party/patch/openssl)
endif() endif()
if(${CMAKE_SYSTEM_NAME} MATCHES "Linux" OR APPLE) if(BUILD_LITE)
mindspore_add_pkg(openssl if(PLATFORM_ARM64 AND ANDROID_NDK_TOOLCHAIN_INCLUDED)
VER 1.1.1k set(ANDROID_NDK_ROOT $ENV{ANDROID_NDK})
LIBS ssl crypto set(PATH
URL ${REQ_URL} ${ANDROID_NDK_ROOT}/toolchains/llvm/prebuilt/linux-x86_64/bin:
MD5 ${MD5} ${ANDROID_NDK_ROOT}/toolchains/arm-linux-androideabi-4.9/prebuilt/linux-x86_64/bin:
CONFIGURE_COMMAND ./config no-zlib no-shared $ENV{PATH})
PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3711.patch mindspore_add_pkg(openssl
PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3712.patch VER 1.1.1k
) LIBS ssl crypto
include_directories(${openssl_INC}) URL ${REQ_URL}
add_library(mindspore::ssl ALIAS openssl::ssl) MD5 ${MD5}
add_library(mindspore::crypto ALIAS openssl::crypto) CONFIGURE_COMMAND ./Configure android-arm64 -D__ANDROID_API__=29 no-zlib
PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3711.patch
PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3712.patch
)
elseif(PLATFORM_ARM32 AND ANDROID_NDK_TOOLCHAIN_INCLUDED)
set(ANDROID_NDK_ROOT $ENV{ANDROID_NDK})
set(PATH
${ANDROID_NDK_ROOT}/toolchains/llvm/prebuilt/linux-x86_64/bin:
${ANDROID_NDK_ROOT}/toolchains/arm-linux-androideabi-4.9/prebuilt/linux-x86_64/bin:
$ENV{PATH})
mindspore_add_pkg(openssl
VER 1.1.1k
LIBS ssl crypto
URL ${REQ_URL}
MD5 ${MD5}
CONFIGURE_COMMAND ./Configure android-arm -D__ANDROID_API__=29 no-zlib
PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3711.patch
PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3712.patch
)
elseif(${CMAKE_SYSTEM_NAME} MATCHES "Linux" OR APPLE)
mindspore_add_pkg(openssl
VER 1.1.1k
LIBS ssl crypto
URL ${REQ_URL}
MD5 ${MD5}
CONFIGURE_COMMAND ./config no-zlib no-shared
PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3711.patch
PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3712.patch
)
else()
MESSAGE(FATAL_ERROR "openssl does not support compilation for the current environment.")
endif()
include_directories(${openssl_INC})
add_library(mindspore::ssl ALIAS openssl::ssl)
add_library(mindspore::crypto ALIAS openssl::crypto)
else()
if(${CMAKE_SYSTEM_NAME} MATCHES "Linux" OR APPLE)
mindspore_add_pkg(openssl
VER 1.1.1k
LIBS ssl crypto
URL ${REQ_URL}
MD5 ${MD5}
CONFIGURE_COMMAND ./config no-zlib no-shared
PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3711.patch
PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3712.patch
)
include_directories(${openssl_INC})
add_library(mindspore::ssl ALIAS openssl::ssl)
add_library(mindspore::crypto ALIAS openssl::crypto)
endif()
endif() endif()

View File

@ -173,7 +173,7 @@ class MS_API Model {
/// \return Status of operation /// \return Status of operation
Status UpdateFeatureMaps(const std::vector<MSTensor> &new_weights); Status UpdateFeatureMaps(const std::vector<MSTensor> &new_weights);
/// \brief Obtains optimizer params tensors of the model. /// \brief Obtains optimizer params tensors of the model.
/// ///
/// \return The vector that includes all params tensors. /// \return The vector that includes all params tensors.
std::vector<MSTensor> GetOptimizerParams() const; std::vector<MSTensor> GetOptimizerParams() const;
@ -256,17 +256,14 @@ class MS_API Model {
/// \brief Build a model from model buffer so that it can run on a device. Only valid for Lite. /// \brief Build a model from model buffer so that it can run on a device. Only valid for Lite.
/// ///
/// \param[in] model_data Define the buffer read from a model file. /// \param[in] model_data Define the buffer read from a model file.
/// \param[in] size Define bytes number of model buffer. /// \param[in] data_size Define bytes number of model buffer.
/// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only /// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only
/// ModelType::kMindIR is valid for Lite. /// ModelType::kMindIR is valid for Lite.
/// \param[in] model_context Define the context used to store options during execution. /// \param[in] model_context Define the context used to store options during execution.
/// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16, 24, or 32.
/// \param[in] dec_mode Define the decryption mode. Options: AES-GCM, AES-CBC.
/// ///
/// \return Status. /// \return Status.
inline Status Build(const void *model_data, size_t data_size, ModelType model_type, Status Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {}, const std::shared_ptr<Context> &model_context = nullptr);
const std::string &dec_mode = kDecModeAesGcm);
/// \brief Load and build a model from model buffer so that it can run on a device. Only valid for Lite. /// \brief Load and build a model from model buffer so that it can run on a device. Only valid for Lite.
/// ///
@ -274,13 +271,40 @@ class MS_API Model {
/// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only /// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only
/// ModelType::kMindIR is valid for Lite. /// ModelType::kMindIR is valid for Lite.
/// \param[in] model_context Define the context used to store options during execution. /// \param[in] model_context Define the context used to store options during execution.
/// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16, 24, or 32.
/// \param[in] dec_mode Define the decryption mode. Options: AES-GCM, AES-CBC.
/// ///
/// \return Status. /// \return Status.
inline Status Build(const std::string &model_path, ModelType model_type, Status Build(const std::string &model_path, ModelType model_type,
const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {}, const std::shared_ptr<Context> &model_context = nullptr);
const std::string &dec_mode = kDecModeAesGcm);
/// \brief Build a model from model buffer so that it can run on a device. Only valid for Lite.
///
/// \param[in] model_data Define the buffer read from a model file.
/// \param[in] data_size Define bytes number of model buffer.
/// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only
/// ModelType::kMindIR is valid for Lite.
/// \param[in] model_context Define the context used to store options during execution.
/// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16.
/// \param[in] dec_mode Define the decryption mode. Options: AES-GCM.
/// \param[in] cropto_lib_path Define the openssl library path.
///
/// \return Status.
Status Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode,
const std::string &cropto_lib_path);
/// \brief Load and build a model from model buffer so that it can run on a device. Only valid for Lite.
///
/// \param[in] model_path Define the model path.
/// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only
/// ModelType::kMindIR is valid for Lite.
/// \param[in] model_context Define the context used to store options during execution.
/// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16.
/// \param[in] dec_mode Define the decryption mode. Options: AES-GCM.
/// \param[in] cropto_lib_path Define the openssl library path.
///
/// \return Status.
Status Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
const Key &dec_key, const std::string &dec_mode, const std::string &cropto_lib_path);
private: private:
friend class Serialization; friend class Serialization;
@ -291,11 +315,10 @@ class MS_API Model {
std::vector<MSTensor> GetOutputsByNodeName(const std::vector<char> &node_name); std::vector<MSTensor> GetOutputsByNodeName(const std::vector<char> &node_name);
Status LoadConfig(const std::vector<char> &config_path); Status LoadConfig(const std::vector<char> &config_path);
Status UpdateConfig(const std::vector<char> &section, const std::pair<std::vector<char>, std::vector<char>> &config); Status UpdateConfig(const std::vector<char> &section, const std::pair<std::vector<char>, std::vector<char>> &config);
Status Build(const void *model_data, size_t data_size, ModelType model_type, Status Build(const std::vector<char> &model_path, ModelType model_type,
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::vector<char> &dec_mode); const std::shared_ptr<Context> &model_context);
Status Build(const std::vector<char> &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context, Status Build(const std::vector<char> &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
const Key &dec_key, const std::vector<char> &dec_mode); const Key &dec_key, const std::string &dec_mode, const std::vector<char> &cropto_lib_path);
std::shared_ptr<ModelImpl> impl_; std::shared_ptr<ModelImpl> impl_;
}; };
@ -321,14 +344,15 @@ Status Model::UpdateConfig(const std::string &section, const std::pair<std::stri
return UpdateConfig(StringToChar(section), config_pair); return UpdateConfig(StringToChar(section), config_pair);
} }
Status Model::Build(const void *model_data, size_t data_size, ModelType model_type, inline Status Model::Build(const std::string &model_path, ModelType model_type,
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode) { const std::shared_ptr<Context> &model_context, const Key &dec_key,
return Build(model_data, data_size, model_type, model_context, dec_key, StringToChar(dec_mode)); const std::string &dec_mode, const std::string &cropto_lib_path) {
return Build(StringToChar(model_path), model_type, model_context, dec_key, dec_mode, StringToChar(cropto_lib_path));
} }
Status Model::Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context, inline Status Model::Build(const std::string &model_path, ModelType model_type,
const Key &dec_key, const std::string &dec_mode) { const std::shared_ptr<Context> &model_context) {
return Build(StringToChar(model_path), model_type, model_context, dec_key, StringToChar(dec_mode)); return Build(StringToChar(model_path), model_type, model_context);
} }
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_MODEL_H #endif // MINDSPORE_INCLUDE_API_MODEL_H

View File

@ -52,14 +52,13 @@ Status Model::Build(GraphCell graph_cell, const std::shared_ptr<Context> &model_
return impl_->Build(); return impl_->Build();
} }
Status Model::Build(const void *, size_t, ModelType, const std::shared_ptr<Context> &, const Key &, Status Model::Build(const std::vector<char> &, ModelType, const std::shared_ptr<Context> &, const Key &,
const std::vector<char> &) { const std::string &, const std::vector<char> &) {
MS_LOG(ERROR) << "Unsupported Feature."; MS_LOG(ERROR) << "Unsupported Feature.";
return kMCFailed; return kMCFailed;
} }
Status Model::Build(const std::vector<char> &, ModelType, const std::shared_ptr<Context> &, const Key &, Status Model::Build(const std::vector<char> &, ModelType, const std::shared_ptr<Context> &) {
const std::vector<char> &) {
MS_LOG(ERROR) << "Unsupported Feature."; MS_LOG(ERROR) << "Unsupported Feature.";
return kMCFailed; return kMCFailed;
} }

View File

@ -120,7 +120,7 @@ bool ParseMode(const std::string &mode, std::string *alg_mode, std::string *work
} }
EVP_CIPHER_CTX *GetEvpCipherCtx(const std::string &work_mode, const Byte *key, int32_t key_len, const Byte *iv, EVP_CIPHER_CTX *GetEvpCipherCtx(const std::string &work_mode, const Byte *key, int32_t key_len, const Byte *iv,
bool is_encrypt) { int iv_len, bool is_encrypt) {
constexpr int32_t key_length_16 = 16; constexpr int32_t key_length_16 = 16;
constexpr int32_t key_length_24 = 24; constexpr int32_t key_length_24 = 24;
constexpr int32_t key_length_32 = 32; constexpr int32_t key_length_32 = 32;
@ -163,8 +163,35 @@ EVP_CIPHER_CTX *GetEvpCipherCtx(const std::string &work_mode, const Byte *key, i
int32_t ret = 0; int32_t ret = 0;
auto ctx = EVP_CIPHER_CTX_new(); auto ctx = EVP_CIPHER_CTX_new();
if (is_encrypt) { if (is_encrypt) {
ret = EVP_EncryptInit_ex(ctx, funcPtr(), NULL, NULL, NULL);
if (ret != 1) {
MS_LOG(ERROR) << "EVP_EncryptInit_ex failed";
EVP_CIPHER_CTX_free(ctx);
return nullptr;
}
if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, iv_len, NULL) != 1) {
MS_LOG(ERROR) << "EVP_EncryptInit_ex failed";
EVP_CIPHER_CTX_free(ctx);
return nullptr;
}
ret = EVP_EncryptInit_ex(ctx, funcPtr(), NULL, key, iv); ret = EVP_EncryptInit_ex(ctx, funcPtr(), NULL, key, iv);
if (ret != 1) {
MS_LOG(ERROR) << "EVP_EncryptInit_ex failed";
EVP_CIPHER_CTX_free(ctx);
return nullptr;
}
} else { } else {
ret = EVP_DecryptInit_ex(ctx, funcPtr(), NULL, NULL, NULL);
if (ret != 1) {
MS_LOG(ERROR) << "EVP_DecryptInit_ex failed";
EVP_CIPHER_CTX_free(ctx);
return nullptr;
}
if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, iv_len, NULL) != 1) {
MS_LOG(ERROR) << "EVP_DecryptInit_ex failed";
EVP_CIPHER_CTX_free(ctx);
return nullptr;
}
ret = EVP_DecryptInit_ex(ctx, funcPtr(), NULL, key, iv); ret = EVP_DecryptInit_ex(ctx, funcPtr(), NULL, key, iv);
} }
@ -183,7 +210,7 @@ EVP_CIPHER_CTX *GetEvpCipherCtx(const std::string &work_mode, const Byte *key, i
} }
bool BlockEncrypt(Byte *encrypt_data, size_t *encrypt_data_len, const std::vector<Byte> &plain_data, const Byte *key, bool BlockEncrypt(Byte *encrypt_data, size_t *encrypt_data_len, const std::vector<Byte> &plain_data, const Byte *key,
int32_t key_len, const std::string &enc_mode) { int32_t key_len, const std::string &enc_mode, unsigned char *tag) {
size_t encrypt_data_buf_len = *encrypt_data_len; size_t encrypt_data_buf_len = *encrypt_data_len;
int32_t cipher_len = 0; int32_t cipher_len = 0;
int32_t iv_len = AES_BLOCK_SIZE; int32_t iv_len = AES_BLOCK_SIZE;
@ -201,7 +228,7 @@ bool BlockEncrypt(Byte *encrypt_data, size_t *encrypt_data_len, const std::vecto
return false; return false;
} }
auto ctx = GetEvpCipherCtx(work_mode, key, key_len, iv.data(), true); auto ctx = GetEvpCipherCtx(work_mode, key, key_len, iv.data(), static_cast<int32_t>(iv.size()), true);
if (ctx == nullptr) { if (ctx == nullptr) {
MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX."; MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX.";
return false; return false;
@ -214,15 +241,19 @@ bool BlockEncrypt(Byte *encrypt_data, size_t *encrypt_data_len, const std::vecto
MS_LOG(ERROR) << "EVP_EncryptUpdate failed"; MS_LOG(ERROR) << "EVP_EncryptUpdate failed";
return false; return false;
} }
if (work_mode == "CBC") { int32_t flen = 0;
int32_t flen = 0; ret_evp = EVP_EncryptFinal_ex(ctx, cipher_data_buf.data() + cipher_len, &flen);
ret_evp = EVP_EncryptFinal_ex(ctx, cipher_data_buf.data() + cipher_len, &flen); if (ret_evp != 1) {
if (ret_evp != 1) { MS_LOG(ERROR) << "EVP_EncryptFinal_ex failed";
MS_LOG(ERROR) << "EVP_EncryptFinal_ex failed"; return false;
return false;
}
cipher_len += flen;
} }
cipher_len += flen;
if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_GET_TAG, Byte16, tag) != 1) {
MS_LOG(ERROR) << "EVP_CIPHER_CTX_ctrl failed";
return false;
}
EVP_CIPHER_CTX_free(ctx); EVP_CIPHER_CTX_free(ctx);
size_t offset = 0; size_t offset = 0;
@ -266,7 +297,7 @@ bool BlockEncrypt(Byte *encrypt_data, size_t *encrypt_data_len, const std::vecto
} }
bool BlockDecrypt(Byte *plain_data, int32_t *plain_len, const Byte *encrypt_data, size_t encrypt_len, const Byte *key, bool BlockDecrypt(Byte *plain_data, int32_t *plain_len, const Byte *encrypt_data, size_t encrypt_len, const Byte *key,
int32_t key_len, const std::string &dec_mode) { int32_t key_len, const std::string &dec_mode, unsigned char *tag) {
std::string alg_mode; std::string alg_mode;
std::string work_mode; std::string work_mode;
if (!ParseMode(dec_mode, &alg_mode, &work_mode)) { if (!ParseMode(dec_mode, &alg_mode, &work_mode)) {
@ -277,7 +308,7 @@ bool BlockDecrypt(Byte *plain_data, int32_t *plain_len, const Byte *encrypt_data
if (!ParseEncryptData(encrypt_data, encrypt_len, &iv, &cipher_data)) { if (!ParseEncryptData(encrypt_data, encrypt_len, &iv, &cipher_data)) {
return false; return false;
} }
auto ctx = GetEvpCipherCtx(work_mode, key, key_len, iv.data(), false); auto ctx = GetEvpCipherCtx(work_mode, key, key_len, iv.data(), iv.size(), false);
if (ctx == nullptr) { if (ctx == nullptr) {
MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX."; MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX.";
return false; return false;
@ -288,15 +319,20 @@ bool BlockDecrypt(Byte *plain_data, int32_t *plain_len, const Byte *encrypt_data
MS_LOG(ERROR) << "EVP_DecryptUpdate failed"; MS_LOG(ERROR) << "EVP_DecryptUpdate failed";
return false; return false;
} }
if (work_mode == "CBC") {
int32_t mlen = 0; if (!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_TAG, Byte16, tag)) {
ret = EVP_DecryptFinal_ex(ctx, plain_data + *plain_len, &mlen); MS_LOG(ERROR) << "EVP_CIPHER_CTX_ctrl failed";
if (ret != 1) { return false;
MS_LOG(ERROR) << "EVP_DecryptFinal_ex failed";
return false;
}
*plain_len += mlen;
} }
int32_t mlen = 0;
ret = EVP_DecryptFinal_ex(ctx, plain_data + *plain_len, &mlen);
if (ret != 1) {
MS_LOG(ERROR) << "EVP_DecryptFinal_ex failed";
return false;
}
*plain_len += mlen;
EVP_CIPHER_CTX_free(ctx); EVP_CIPHER_CTX_free(ctx);
return true; return true;
} }
@ -319,7 +355,9 @@ std::unique_ptr<Byte[]> Encrypt(size_t *encrypt_len, const Byte *plain_data, siz
size_t block_enc_len = block_enc_buf.size(); size_t block_enc_len = block_enc_buf.size();
size_t cur_block_size = std::min(MAX_BLOCK_SIZE, plain_len - offset); size_t cur_block_size = std::min(MAX_BLOCK_SIZE, plain_len - offset);
block_buf.assign(plain_data + offset, plain_data + offset + cur_block_size); block_buf.assign(plain_data + offset, plain_data + offset + cur_block_size);
if (!BlockEncrypt(block_enc_buf.data(), &block_enc_len, block_buf, key, static_cast<int32_t>(key_len), enc_mode)) { unsigned char tag[Byte16];
if (!BlockEncrypt(block_enc_buf.data(), &block_enc_len, block_buf, key, static_cast<int32_t>(key_len), enc_mode,
tag)) {
MS_LOG(ERROR) << "Failed to encrypt data, please check if enc_key or enc_mode is valid."; MS_LOG(ERROR) << "Failed to encrypt data, please check if enc_key or enc_mode is valid.";
return nullptr; return nullptr;
} }
@ -332,6 +370,13 @@ std::unique_ptr<Byte[]> Encrypt(size_t *encrypt_len, const Byte *plain_data, siz
} }
*encrypt_len += sizeof(int32_t); *encrypt_len += sizeof(int32_t);
capacity = std::min(encrypt_buf_len - *encrypt_len, SECUREC_MEM_MAX_LEN); // avoid dest size over 2gb
ret = memcpy_s(encrypt_data.get() + *encrypt_len, capacity, tag, Byte16);
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret;
}
*encrypt_len += Byte16;
capacity = std::min(encrypt_buf_len - *encrypt_len, SECUREC_MEM_MAX_LEN); capacity = std::min(encrypt_buf_len - *encrypt_len, SECUREC_MEM_MAX_LEN);
ret = memcpy_s(encrypt_data.get() + *encrypt_len, capacity, block_enc_buf.data(), block_enc_len); ret = memcpy_s(encrypt_data.get() + *encrypt_len, capacity, block_enc_buf.data(), block_enc_len);
if (ret != 0) { if (ret != 0) {
@ -371,6 +416,10 @@ std::unique_ptr<Byte[]> Decrypt(size_t *decrypt_len, const std::string &encrypt_
MS_LOG(ERROR) << "File \"" << encrypt_data_path << "\" is not an encrypted file and cannot be decrypted"; MS_LOG(ERROR) << "File \"" << encrypt_data_path << "\" is not an encrypted file and cannot be decrypted";
return nullptr; return nullptr;
} }
unsigned char tag[Byte16];
fid.read(reinterpret_cast<char *>(tag), Byte16);
fid.read(int_buf.data(), static_cast<int64_t>(sizeof(int32_t))); fid.read(int_buf.data(), static_cast<int64_t>(sizeof(int32_t)));
auto block_size = ByteToInt(reinterpret_cast<Byte *>(int_buf.data()), int_buf.size()); auto block_size = ByteToInt(reinterpret_cast<Byte *>(int_buf.data()), int_buf.size());
if (block_size < 0) { if (block_size < 0) {
@ -379,7 +428,7 @@ std::unique_ptr<Byte[]> Decrypt(size_t *decrypt_len, const std::string &encrypt_
} }
fid.read(block_buf.data(), static_cast<int64_t>(block_size)); fid.read(block_buf.data(), static_cast<int64_t>(block_size));
if (!(BlockDecrypt(decrypt_block_buf.data(), &decrypt_block_len, reinterpret_cast<Byte *>(block_buf.data()), if (!(BlockDecrypt(decrypt_block_buf.data(), &decrypt_block_len, reinterpret_cast<Byte *>(block_buf.data()),
static_cast<size_t>(block_size), key, static_cast<int32_t>(key_len), dec_mode))) { static_cast<size_t>(block_size), key, static_cast<int32_t>(key_len), dec_mode, tag))) {
MS_LOG(ERROR) << "Failed to decrypt data, please check if dec_key or dec_mode is valid"; MS_LOG(ERROR) << "Failed to decrypt data, please check if dec_key or dec_mode is valid";
return nullptr; return nullptr;
} }
@ -409,6 +458,10 @@ std::unique_ptr<Byte[]> Decrypt(size_t *decrypt_len, const Byte *model_data, siz
size_t offset = 0; size_t offset = 0;
*decrypt_len = 0; *decrypt_len = 0;
while (offset < data_size) { while (offset < data_size) {
if (offset + sizeof(int32_t) > data_size) {
MS_LOG(ERROR) << "assign len is invalid.";
return nullptr;
}
int_buf.assign(model_data + offset, model_data + offset + sizeof(int32_t)); int_buf.assign(model_data + offset, model_data + offset + sizeof(int32_t));
offset += int_buf.size(); offset += int_buf.size();
auto cipher_flag = ByteToInt(reinterpret_cast<Byte *>(int_buf.data()), int_buf.size()); auto cipher_flag = ByteToInt(reinterpret_cast<Byte *>(int_buf.data()), int_buf.size());
@ -416,27 +469,44 @@ std::unique_ptr<Byte[]> Decrypt(size_t *decrypt_len, const Byte *model_data, siz
MS_LOG(ERROR) << "model_data is not encrypted and therefore cannot be decrypted."; MS_LOG(ERROR) << "model_data is not encrypted and therefore cannot be decrypted.";
return nullptr; return nullptr;
} }
unsigned char tag[Byte16];
if (offset + Byte16 > data_size) {
MS_LOG(ERROR) << "buffer is invalid.";
return nullptr;
}
auto ret = memcpy_s(tag, Byte16, model_data + offset, Byte16);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "memcpy_s failed " << ret;
}
offset += Byte16;
if (offset + sizeof(int32_t) > data_size) {
MS_LOG(ERROR) << "assign len is invalid.";
return nullptr;
}
int_buf.assign(model_data + offset, model_data + offset + sizeof(int32_t)); int_buf.assign(model_data + offset, model_data + offset + sizeof(int32_t));
offset += int_buf.size(); offset += int_buf.size();
auto block_size = ByteToInt(reinterpret_cast<Byte *>(int_buf.data()), int_buf.size()); auto block_size = ByteToInt(reinterpret_cast<Byte *>(int_buf.data()), int_buf.size());
if (block_size < 0) { if (block_size <= 0) {
MS_LOG(ERROR) << "The block_size read from the cipher data must be not negative, but got " << block_size; MS_LOG(ERROR) << "The block_size read from the cipher data must be not negative, but got " << block_size;
return nullptr; return nullptr;
} }
if (offset + block_size > data_size) {
MS_LOG(ERROR) << "assign len is invalid.";
return nullptr;
}
block_buf.assign(model_data + offset, model_data + offset + block_size); block_buf.assign(model_data + offset, model_data + offset + block_size);
offset += block_buf.size(); offset += block_buf.size();
if (!(BlockDecrypt(decrypt_block_buf.data(), &decrypt_block_len, reinterpret_cast<Byte *>(block_buf.data()), if (!(BlockDecrypt(decrypt_block_buf.data(), &decrypt_block_len, reinterpret_cast<Byte *>(block_buf.data()),
block_buf.size(), key, static_cast<int32_t>(key_len), dec_mode))) { block_buf.size(), key, static_cast<int32_t>(key_len), dec_mode, tag))) {
MS_LOG(ERROR) << "Failed to decrypt data, please check if dec_key or dec_mode is valid"; MS_LOG(ERROR) << "Failed to decrypt data, please check if dec_key or dec_mode is valid";
return nullptr; return nullptr;
} }
size_t capacity = std::min(data_size - *decrypt_len, SECUREC_MEM_MAX_LEN); ret = memcpy_s(decrypt_data.get() + *decrypt_len, data_size, decrypt_block_buf.data(),
auto ret = memcpy_s(decrypt_data.get() + *decrypt_len, capacity, decrypt_block_buf.data(), static_cast<size_t>(decrypt_block_len));
static_cast<size_t>(decrypt_block_len)); if (ret != EOK) {
if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s failed " << ret;
MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret;
} }
*decrypt_len += static_cast<size_t>(decrypt_block_len); *decrypt_len += static_cast<size_t>(decrypt_block_len);
} }
return decrypt_data; return decrypt_data;

View File

@ -26,6 +26,7 @@ namespace mindspore {
constexpr size_t MAX_BLOCK_SIZE = 512 * 1024 * 1024; // Maximum ciphertext segment, units is Byte constexpr size_t MAX_BLOCK_SIZE = 512 * 1024 * 1024; // Maximum ciphertext segment, units is Byte
constexpr size_t RESERVED_BYTE_PER_BLOCK = 50; // Reserved byte per block to save addition info constexpr size_t RESERVED_BYTE_PER_BLOCK = 50; // Reserved byte per block to save addition info
constexpr unsigned int MAGIC_NUM = 0x7F3A5ED8; // Magic number constexpr unsigned int MAGIC_NUM = 0x7F3A5ED8; // Magic number
constexpr size_t Byte16 = 16;
MS_CORE_API std::unique_ptr<Byte[]> Encrypt(size_t *encrypt_len, const Byte *plain_data, size_t plain_len, MS_CORE_API std::unique_ptr<Byte[]> Encrypt(size_t *encrypt_len, const Byte *plain_data, size_t plain_len,
const Byte *key, size_t key_len, const std::string &enc_mode); const Byte *key, size_t key_len, const std::string &enc_mode);

View File

@ -34,7 +34,7 @@ option(MSLITE_ENABLE_V0 "support v0 schema" on)
option(MSLITE_ENABLE_FP16 "Whether to compile Fp16 operator" off) option(MSLITE_ENABLE_FP16 "Whether to compile Fp16 operator" off)
option(MSLITE_ENABLE_INT8 "Whether to compile Int8 operator" on) option(MSLITE_ENABLE_INT8 "Whether to compile Int8 operator" on)
option(MSLITE_ENABLE_ACL "enable ACL" off) option(MSLITE_ENABLE_ACL "enable ACL" off)
option(MSLITE_ENABLE_MODEL_ENCRYPTION "enable model encryption, only converter support" on) option(MSLITE_ENABLE_MODEL_ENCRYPTION "enable model encryption, only converter support" off)
option(MSLITE_ENABLE_SPARSE_COMPUTE "enable sparse kernel" off) option(MSLITE_ENABLE_SPARSE_COMPUTE "enable sparse kernel" off)
option(MSLITE_ENABLE_RUNTIME_CONVERT "enable runtime convert" off) option(MSLITE_ENABLE_RUNTIME_CONVERT "enable runtime convert" off)
option(MSLITE_ENABLE_RUNTIME_GLOG "enable runtime glog" off) option(MSLITE_ENABLE_RUNTIME_GLOG "enable runtime glog" off)
@ -127,7 +127,11 @@ if(DEFINED ENV{MSLITE_MINDDATA_IMPLEMENT})
set(MSLITE_MINDDATA_IMPLEMENT $ENV{MSLITE_MINDDATA_IMPLEMENT}) set(MSLITE_MINDDATA_IMPLEMENT $ENV{MSLITE_MINDDATA_IMPLEMENT})
endif() endif()
if(DEFINED ENV{MSLITE_ENABLE_MODEL_ENCRYPTION}) if(DEFINED ENV{MSLITE_ENABLE_MODEL_ENCRYPTION})
set(MSLITE_ENABLE_MODEL_ENCRYPTION $ENV{MSLITE_ENABLE_MODEL_ENCRYPTION}) if((${CMAKE_SYSTEM_NAME} MATCHES "Linux" AND PLATFORM_X86_64) OR (PLATFORM_ARM AND ANDROID_NDK_TOOLCHAIN_INCLUDED))
set(MSLITE_ENABLE_MODEL_ENCRYPTION $ENV{MSLITE_ENABLE_MODEL_ENCRYPTION})
else()
set(MSLITE_ENABLE_MODEL_ENCRYPTION OFF)
endif()
endif() endif()
if(DEFINED ENV{MSLITE_ENABLE_RUNTIME_CONVERT}) if(DEFINED ENV{MSLITE_ENABLE_RUNTIME_CONVERT})
@ -227,7 +231,7 @@ if(PLATFORM_ARM64 OR PLATFORM_ARM32)
endif() endif()
set(MSLITE_ENABLE_RUNTIME_GLOG off) set(MSLITE_ENABLE_RUNTIME_GLOG off)
set(MSLITE_ENABLE_RUNTIME_CONVERT off) set(MSLITE_ENABLE_RUNTIME_CONVERT off)
#set for cross - compiling toolchain #set for cross - compiling toolchain
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH) set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH)
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE BOTH) set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE BOTH)
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE BOTH) set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE BOTH)
@ -540,13 +544,15 @@ if(MSLITE_ENABLE_CONVERTER)
include_directories(${PYTHON_INCLUDE_DIRS}) include_directories(${PYTHON_INCLUDE_DIRS})
include(${TOP_DIR}/cmake/external_libs/eigen.cmake) include(${TOP_DIR}/cmake/external_libs/eigen.cmake)
include(${TOP_DIR}/cmake/external_libs/protobuf.cmake) include(${TOP_DIR}/cmake/external_libs/protobuf.cmake)
if(MSLITE_ENABLE_MODEL_ENCRYPTION)
find_package(Patch)
include(${TOP_DIR}/cmake/external_libs/openssl.cmake)
endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter)
endif() endif()
if(MSLITE_ENABLE_MODEL_ENCRYPTION)
find_package(Patch)
include(${TOP_DIR}/cmake/external_libs/openssl.cmake)
add_compile_definitions(ENABLE_OPENSSL)
endif()
if(MSLITE_ENABLE_MINDRT) if(MSLITE_ENABLE_MINDRT)
add_compile_definitions(ENABLE_MINDRT) add_compile_definitions(ENABLE_MINDRT)
endif() endif()
@ -590,7 +596,7 @@ if(NOT PLATFORM_ARM)
endif() endif()
if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "lite" OR MSLITE_MINDDATA_IMPLEMENT STREQUAL "full" if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "lite" OR MSLITE_MINDDATA_IMPLEMENT STREQUAL "full"
OR MSLITE_MINDDATA_IMPLEMENT STREQUAL "wrapper") OR MSLITE_MINDDATA_IMPLEMENT STREQUAL "wrapper")
add_compile_definitions(ENABLE_ANDROID) add_compile_definitions(ENABLE_ANDROID)
if(NOT PLATFORM_ARM32 AND NOT PLATFORM_ARM64) if(NOT PLATFORM_ARM32 AND NOT PLATFORM_ARM64)
add_compile_definitions(ENABLE_MD_LITE_X86_64) add_compile_definitions(ENABLE_MD_LITE_X86_64)
@ -605,7 +611,7 @@ endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/src/ops) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/src/ops)
if(ANDROID_NDK_TOOLCHAIN_INCLUDED) if(ANDROID_NDK_TOOLCHAIN_INCLUDED)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter/micro/coder) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter/micro/coder)
endif() endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/src) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/src)

View File

@ -206,6 +206,7 @@ build_lite() {
LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_ENABLE_FP16=off -DMSLITE_ENABLE_TRAIN=off -DMSLITE_GPU_BACKEND=off" LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_ENABLE_FP16=off -DMSLITE_ENABLE_TRAIN=off -DMSLITE_GPU_BACKEND=off"
else else
checkndk checkndk
export PATH=${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/bin:${ANDROID_NDK}/toolchains/arm-linux-androideabi-4.9/prebuilt/linux-x86_64/bin:${PATH}
CMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake CMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake
LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_MINDDATA_IMPLEMENT=lite_cv" LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_MINDDATA_IMPLEMENT=lite_cv"
LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_ENABLE_FP16=on" LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_ENABLE_FP16=on"
@ -237,6 +238,7 @@ build_lite() {
ARM64_COMPILE_CONVERTER=ON ARM64_COMPILE_CONVERTER=ON
else else
checkndk checkndk
export PATH=${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/bin:${ANDROID_NDK}/toolchains/arm-linux-androideabi-4.9/prebuilt/linux-x86_64/bin:${PATH}
CMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake CMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake
LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DANDROID_NATIVE_API_LEVEL=19 -DANDROID_NDK=${ANDROID_NDK} -DANDROID_ABI=arm64-v8a -DANDROID_TOOLCHAIN_NAME=aarch64-linux-android-clang -DANDROID_STL=${MSLITE_ANDROID_STL}" LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DANDROID_NATIVE_API_LEVEL=19 -DANDROID_NDK=${ANDROID_NDK} -DANDROID_ABI=arm64-v8a -DANDROID_TOOLCHAIN_NAME=aarch64-linux-android-clang -DANDROID_STL=${MSLITE_ANDROID_STL}"
LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_MINDDATA_IMPLEMENT=lite_cv" LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_MINDDATA_IMPLEMENT=lite_cv"

View File

@ -58,19 +58,19 @@ public class Model {
/** /**
* Build model. * Build model.
* *
* @param buffer model buffer. * @param buffer model buffer.
* @param modelType model type. * @param modelType model type.
* @param context model build context. * @param context model build context.
* @param dec_key define the key used to decrypt the ciphertext model. The key length is 16, 24, or 32. * @param dec_key define the key used to decrypt the ciphertext model. The key length is 16.
* @param dec_mode define the decryption mode. Options: AES-GCM, AES-CBC. * @param dec_mode define the decryption mode. Options: AES-GCM.
* @param cropto_lib_path define the openssl library path.
* @return model build status. * @return model build status.
*/ */
public boolean build(final MappedByteBuffer buffer, int modelType, MSContext context, char[] dec_key, public boolean build(final MappedByteBuffer buffer, int modelType, MSContext context, char[] dec_key, String dec_mode, String cropto_lib_path) {
String dec_mode) {
if (context == null || buffer == null || dec_key == null || dec_mode == null) { if (context == null || buffer == null || dec_key == null || dec_mode == null) {
return false; return false;
} }
modelPtr = this.buildByBuffer(buffer, modelType, context.getMSContextPtr(), dec_key, dec_mode); modelPtr = this.buildByBuffer(buffer, modelType, context.getMSContextPtr(), dec_key, dec_mode, cropto_lib_path);
return modelPtr != 0; return modelPtr != 0;
} }
@ -86,7 +86,7 @@ public class Model {
if (context == null || buffer == null) { if (context == null || buffer == null) {
return false; return false;
} }
modelPtr = this.buildByBuffer(buffer, modelType, context.getMSContextPtr(), null, ""); modelPtr = this.buildByBuffer(buffer, modelType, context.getMSContextPtr(), null, "", "");
return modelPtr != 0; return modelPtr != 0;
} }
@ -94,18 +94,19 @@ public class Model {
/** /**
* Build model. * Build model.
* *
* @param modelPath model path. * @param modelPath model path.
* @param modelType model type. * @param modelType model type.
* @param context model build context. * @param context model build context.
* @param dec_key define the key used to decrypt the ciphertext model. The key length is 16, 24, or 32. * @param dec_key define the key used to decrypt the ciphertext model. The key length is 16.
* @param dec_mode define the decryption mode. Options: AES-GCM, AES-CBC. * @param dec_mode define the decryption mode. Options: AES-GCM.
* @param cropto_lib_path define the openssl library path.
* @return model build status. * @return model build status.
*/ */
public boolean build(String modelPath, int modelType, MSContext context, char[] dec_key, String dec_mode) { public boolean build(String modelPath, int modelType, MSContext context, char[] dec_key, String dec_mode, String cropto_lib_path) {
if (context == null || modelPath == null || dec_key == null || dec_mode == null) { if (context == null || modelPath == null || dec_key == null || dec_mode == null) {
return false; return false;
} }
modelPtr = this.buildByPath(modelPath, modelType, context.getMSContextPtr(), dec_key, dec_mode); modelPtr = this.buildByPath(modelPath, modelType, context.getMSContextPtr(), dec_key, dec_mode, cropto_lib_path);
return modelPtr != 0; return modelPtr != 0;
} }
@ -121,7 +122,7 @@ public class Model {
if (context == null || modelPath == null) { if (context == null || modelPath == null) {
return false; return false;
} }
modelPtr = this.buildByPath(modelPath, modelType, context.getMSContextPtr(), null, ""); modelPtr = this.buildByPath(modelPath, modelType, context.getMSContextPtr(), null, "", "");
return modelPtr != 0; return modelPtr != 0;
} }
@ -256,8 +257,7 @@ public class Model {
* @param outputTensorNames tensor name used for export inference graph. * @param outputTensorNames tensor name used for export inference graph.
* @return Whether the export is successful. * @return Whether the export is successful.
*/ */
public boolean export(String fileName, int quantizationType, boolean isOnlyExportInfer, public boolean export(String fileName, int quantizationType, boolean isOnlyExportInfer, List<String> outputTensorNames) {
List<String> outputTensorNames) {
if (fileName == null) { if (fileName == null) {
return false; return false;
} }
@ -355,10 +355,11 @@ public class Model {
private native long buildByGraph(long graphPtr, long contextPtr, long cfgPtr); private native long buildByGraph(long graphPtr, long contextPtr, long cfgPtr);
private native long buildByPath(String modelPath, int modelType, long contextPtr, char[] dec_key, String dec_mod); private native long buildByPath(String modelPath, int modelType, long contextPtr,
char[] dec_key, String dec_mod, String cropto_lib_path);
private native long buildByBuffer(MappedByteBuffer buffer, int modelType, long contextPtr, char[] dec_key, private native long buildByBuffer(MappedByteBuffer buffer, int modelType, long contextPtr,
String dec_mod); char[] dec_key, String dec_mod, String cropto_lib_path);
private native List<Long> getInputs(long modelPtr); private native List<Long> getInputs(long modelPtr);
@ -380,8 +381,7 @@ public class Model {
private native boolean resize(long modelPtr, long[] inputs, int[][] dims); private native boolean resize(long modelPtr, long[] inputs, int[][] dims);
private native boolean export(long modelPtr, String fileName, int quantizationType, boolean isOnlyExportInfer, private native boolean export(long modelPtr, String fileName, int quantizationType, boolean isOnlyExportInfer, String[] outputTensorNames);
String[] outputTensorNames);
private native List<Long> getFeatureMaps(long modelPtr); private native List<Long> getFeatureMaps(long modelPtr);
@ -389,6 +389,5 @@ public class Model {
private native boolean setLearningRate(long modelPtr, float learning_rate); private native boolean setLearningRate(long modelPtr, float learning_rate);
private native boolean setupVirtualBatch(long modelPtr, int virtualBatchMultiplier, float learningRate, private native boolean setupVirtualBatch(long modelPtr, int virtualBatchMultiplier, float learningRate, float momentum);
float momentum);
} }

View File

@ -68,7 +68,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByGraph(JNIEnv
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByBuffer(JNIEnv *env, jobject thiz, extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByBuffer(JNIEnv *env, jobject thiz,
jobject model_buffer, jint model_type, jobject model_buffer, jint model_type,
jlong context_ptr, jcharArray key_str, jlong context_ptr, jcharArray key_str,
jstring dec_mod) { jstring dec_mod, jstring cropto_lib_path) {
if (model_buffer == nullptr) { if (model_buffer == nullptr) {
MS_LOGE("Buffer from java is nullptr"); MS_LOGE("Buffer from java is nullptr");
return reinterpret_cast<jlong>(nullptr); return reinterpret_cast<jlong>(nullptr);
@ -116,7 +116,8 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByBuffer(JNIEnv
} }
env->ReleaseCharArrayElements(key_str, key_array, JNI_ABORT); env->ReleaseCharArrayElements(key_str, key_array, JNI_ABORT);
mindspore::Key dec_key{dec_key_data, key_len}; mindspore::Key dec_key{dec_key_data, key_len};
status = model->Build(model_buf, buffer_len, c_model_type, context, dec_key, c_dec_mod); auto c_cropto_lib_path = env->GetStringUTFChars(cropto_lib_path, JNI_FALSE);
status = model->Build(model_buf, buffer_len, c_model_type, context, dec_key, c_dec_mod, c_cropto_lib_path);
} else { } else {
status = model->Build(model_buf, buffer_len, c_model_type, context); status = model->Build(model_buf, buffer_len, c_model_type, context);
} }
@ -130,7 +131,8 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByBuffer(JNIEnv
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByPath(JNIEnv *env, jobject thiz, jstring model_path, extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByPath(JNIEnv *env, jobject thiz, jstring model_path,
jint model_type, jlong context_ptr, jint model_type, jlong context_ptr,
jcharArray key_str, jstring dec_mod) { jcharArray key_str, jstring dec_mod,
jstring cropto_lib_path) {
auto c_model_path = env->GetStringUTFChars(model_path, JNI_FALSE); auto c_model_path = env->GetStringUTFChars(model_path, JNI_FALSE);
mindspore::ModelType c_model_type; mindspore::ModelType c_model_type;
if (model_type >= static_cast<int>(mindspore::kMindIR) && model_type <= static_cast<int>(mindspore::kMindIR_Lite)) { if (model_type >= static_cast<int>(mindspore::kMindIR) && model_type <= static_cast<int>(mindspore::kMindIR_Lite)) {
@ -172,7 +174,8 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByPath(JNIEnv *
} }
env->ReleaseCharArrayElements(key_str, key_array, JNI_ABORT); env->ReleaseCharArrayElements(key_str, key_array, JNI_ABORT);
mindspore::Key dec_key{dec_key_data, key_len}; mindspore::Key dec_key{dec_key_data, key_len};
status = model->Build(c_model_path, c_model_type, context, dec_key, c_dec_mod); auto c_cropto_lib_path = env->GetStringUTFChars(cropto_lib_path, JNI_FALSE);
status = model->Build(c_model_path, c_model_type, context, dec_key, c_dec_mod, c_cropto_lib_path);
} else { } else {
status = model->Build(c_model_path, c_model_type, context); status = model->Build(c_model_path, c_model_type, context);
} }

View File

@ -131,6 +131,14 @@ set(LITE_SRC
${CMAKE_CURRENT_SOURCE_DIR}/cpu_info.cc ${CMAKE_CURRENT_SOURCE_DIR}/cpu_info.cc
) )
if(MSLITE_ENABLE_MODEL_ENCRYPTION)
set(LITE_SRC
${LITE_SRC}
${CMAKE_CURRENT_SOURCE_DIR}/common/decrypt.cc
${CMAKE_CURRENT_SOURCE_DIR}/common/dynamic_library_loader.cc
)
endif()
if(MSLITE_ENABLE_SERVER_INFERENCE) if(MSLITE_ENABLE_SERVER_INFERENCE)
set(LITE_SRC set(LITE_SRC
${LITE_SRC} ${LITE_SRC}
@ -272,8 +280,7 @@ set(TRAIN_SRC
${CMAKE_CURRENT_SOURCE_DIR}/train/classification_train_accuracy_monitor.cc ${CMAKE_CURRENT_SOURCE_DIR}/train/classification_train_accuracy_monitor.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/train_export.cc ${CMAKE_CURRENT_SOURCE_DIR}/train/train_export.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/opt_allocator.cc ${CMAKE_CURRENT_SOURCE_DIR}/train/opt_allocator.cc
${TOOLS_DIR}/common/storage.cc ${CMAKE_CURRENT_SOURCE_DIR}/common/storage.cc
${TOOLS_DIR}/common/meta_graph_serializer.cc
${TOOLS_DIR}/converter/optimizer.cc ${TOOLS_DIR}/converter/optimizer.cc
${TOOLS_DIR}/converter/legacy_optimizer/fusion/fusion_pass.cc ${TOOLS_DIR}/converter/legacy_optimizer/fusion/fusion_pass.cc
${TOOLS_DIR}/converter/legacy_optimizer/fusion/fusion_pattern.cc ${TOOLS_DIR}/converter/legacy_optimizer/fusion/fusion_pattern.cc

View File

@ -0,0 +1,323 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/decrypt.h"
#ifdef ENABLE_OPENSSL
#include <openssl/aes.h>
#include <openssl/evp.h>
#include <openssl/rand.h>
#include <regex>
#include <vector>
#include <fstream>
#include <algorithm>
#include "src/common/dynamic_library_loader.h"
#include "src/common/file_utils.h"
#endif
#include "src/common/log_adapter.h"
#include "src/common/log_util.h"
#ifndef SECUREC_MEM_MAX_LEN
#define SECUREC_MEM_MAX_LEN 0x7fffffffUL
#endif
namespace mindspore::lite {
#ifndef ENABLE_OPENSSL
std::unique_ptr<Byte[]> Decrypt(const std::string &lib_path, size_t *, const Byte *, const size_t, const Byte *,
const size_t, const std::string &) {
MS_LOG(ERROR) << "The feature is only supported on the Linux platform "
"when the OPENSSL compilation option is enabled.";
return nullptr;
}
#else
namespace {
constexpr size_t MAX_BLOCK_SIZE = 512 * 1024 * 1024; // Maximum ciphertext segment, units is Byte
constexpr size_t Byte16 = 16; // Byte16
constexpr unsigned int MAGIC_NUM = 0x7F3A5ED8; // Magic number
DynamicLibraryLoader loader;
} // namespace
int32_t ByteToInt(const Byte *byteArray, size_t length) {
if (byteArray == nullptr) {
MS_LOG(ERROR) << "There is a null pointer in the input parameter.";
return -1;
}
if (length < sizeof(int32_t)) {
MS_LOG(ERROR) << "Length of byteArray is " << length << ", less than sizeof(int32_t): 4.";
return -1;
}
return *(reinterpret_cast<const int32_t *>(byteArray));
}
bool ParseEncryptData(const Byte *encrypt_data, size_t encrypt_len, std::vector<Byte> *iv,
std::vector<Byte> *cipher_data) {
if (encrypt_data == nullptr || iv == nullptr || cipher_data == nullptr) {
MS_LOG(ERROR) << "There is a null pointer in the input parameter.";
return false;
}
// encrypt_data is organized in order to iv_len, iv, cipher_len, cipher_data
std::vector<Byte> int_buf(sizeof(int32_t));
if (sizeof(int32_t) > encrypt_len) {
MS_LOG(ERROR) << "assign len is invalid.";
return false;
}
int_buf.assign(encrypt_data, encrypt_data + sizeof(int32_t));
auto iv_len = ByteToInt(int_buf.data(), int_buf.size());
if (iv_len <= 0 || iv_len + sizeof(int32_t) + sizeof(int32_t) > encrypt_len) {
MS_LOG(ERROR) << "assign len is invalid.";
return false;
}
int_buf.assign(encrypt_data + iv_len + sizeof(int32_t), encrypt_data + iv_len + sizeof(int32_t) + sizeof(int32_t));
auto cipher_len = ByteToInt(int_buf.data(), int_buf.size());
if (((static_cast<size_t>(iv_len) + sizeof(int32_t) + static_cast<size_t>(cipher_len) + sizeof(int32_t)) !=
encrypt_len)) {
MS_LOG(ERROR) << "Failed to parse encrypt data.";
return false;
}
(*iv).assign(encrypt_data + sizeof(int32_t), encrypt_data + sizeof(int32_t) + iv_len);
if (cipher_len <= 0 || sizeof(int32_t) + iv_len + sizeof(int32_t) + cipher_len > encrypt_len) {
MS_LOG(ERROR) << "assign len is invalid.";
return false;
}
(*cipher_data)
.assign(encrypt_data + sizeof(int32_t) + iv_len + sizeof(int32_t),
encrypt_data + sizeof(int32_t) + iv_len + sizeof(int32_t) + cipher_len);
return true;
}
bool ParseMode(const std::string &mode, std::string *alg_mode, std::string *work_mode) {
if (alg_mode == nullptr || work_mode == nullptr) {
MS_LOG(ERROR) << "There is a null pointer in the input parameter.";
return false;
}
std::smatch results;
std::regex re("([A-Z]{3})-([A-Z]{3})");
if (!(std::regex_match(mode.c_str(), re) && std::regex_search(mode, results, re))) {
MS_LOG(ERROR) << "Mode " << mode << " is invalid.";
return false;
}
const size_t index_1 = 1;
const size_t index_2 = 2;
*alg_mode = results[index_1];
*work_mode = results[index_2];
return true;
}
EVP_CIPHER_CTX *GetEvpCipherCtx(const std::string &work_mode, const Byte *key, int32_t key_len, const Byte *iv,
int iv_len) {
constexpr int32_t key_length_16 = 16;
constexpr int32_t key_length_24 = 24;
constexpr int32_t key_length_32 = 32;
const EVP_CIPHER *(*funcPtr)() = nullptr;
if (work_mode == "GCM") {
switch (key_len) {
case key_length_16:
funcPtr = (const EVP_CIPHER *(*)())loader.GetFunc("EVP_aes_128_gcm");
break;
case key_length_24:
funcPtr = (const EVP_CIPHER *(*)())loader.GetFunc("EVP_aes_192_gcm");
break;
case key_length_32:
funcPtr = (const EVP_CIPHER *(*)())loader.GetFunc("EVP_aes_256_gcm");
break;
default:
MS_LOG(ERROR) << "The key length must be 16, 24 or 32, but got key length is " << key_len << ".";
return nullptr;
}
} else {
MS_LOG(ERROR) << "Work mode " << work_mode << " is invalid.";
return nullptr;
}
int32_t ret = 0;
EVP_CIPHER_CTX *(*EVP_CIPHER_CTX_new)() = (EVP_CIPHER_CTX * (*)()) loader.GetFunc("EVP_CIPHER_CTX_new");
EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
if (ctx == nullptr) {
MS_LOG(ERROR) << "EVP_CIPHER_CTX_new failed";
return nullptr;
}
int (*EVP_DecryptInit_ex)(EVP_CIPHER_CTX *, const EVP_CIPHER *, ENGINE *, const unsigned char *,
const unsigned char *) =
(int (*)(EVP_CIPHER_CTX *, const EVP_CIPHER *, ENGINE *, const unsigned char *,
const unsigned char *))loader.GetFunc("EVP_DecryptInit_ex");
ret = EVP_DecryptInit_ex(ctx, funcPtr(), NULL, NULL, NULL);
int (*EVP_CIPHER_CTX_ctrl)(EVP_CIPHER_CTX *, int, int, void *) =
(int (*)(EVP_CIPHER_CTX * ctx, int type, int arg, void *ptr)) loader.GetFunc("EVP_CIPHER_CTX_ctrl");
if (ret != 1) {
MS_LOG(ERROR) << "EVP_DecryptInit_ex failed";
void (*EVP_CIPHER_CTX_free)(EVP_CIPHER_CTX *) = (void (*)(EVP_CIPHER_CTX *))loader.GetFunc("EVP_CIPHER_CTX_free");
EVP_CIPHER_CTX_free(ctx);
return nullptr;
}
if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, iv_len, NULL) != 1) {
MS_LOG(ERROR) << "EVP_DecryptInit_ex failed";
void (*EVP_CIPHER_CTX_free)(EVP_CIPHER_CTX *) = (void (*)(EVP_CIPHER_CTX *))loader.GetFunc("EVP_CIPHER_CTX_free");
EVP_CIPHER_CTX_free(ctx);
return nullptr;
}
ret = EVP_DecryptInit_ex(ctx, funcPtr(), NULL, key, iv);
if (ret != 1) {
MS_LOG(ERROR) << "EVP_DecryptInit_ex failed";
void (*EVP_CIPHER_CTX_free)(EVP_CIPHER_CTX *) = (void (*)(EVP_CIPHER_CTX *))loader.GetFunc("EVP_CIPHER_CTX_free");
EVP_CIPHER_CTX_free(ctx);
return nullptr;
}
return ctx;
}
bool BlockDecrypt(Byte *plain_data, int32_t *plain_len, const Byte *encrypt_data, size_t encrypt_len, const Byte *key,
int32_t key_len, const std::string &dec_mode, unsigned char *tag) {
if (plain_data == nullptr || plain_len == nullptr || encrypt_data == nullptr || key == nullptr) {
MS_LOG(ERROR) << "There is a null pointer in the input parameter.";
return false;
}
std::string alg_mode;
std::string work_mode;
if (!ParseMode(dec_mode, &alg_mode, &work_mode)) {
return false;
}
std::vector<Byte> iv;
std::vector<Byte> cipher_data;
if (!ParseEncryptData(encrypt_data, encrypt_len, &iv, &cipher_data)) {
return false;
}
auto ctx = GetEvpCipherCtx(work_mode, key, key_len, iv.data(), static_cast<int32_t>(iv.size()));
if (ctx == nullptr) {
MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX.";
return false;
}
int (*EVP_DecryptUpdate)(EVP_CIPHER_CTX *, unsigned char *, int *, const unsigned char *, int) =
(int (*)(EVP_CIPHER_CTX *, unsigned char *, int *, const unsigned char *, int))loader.GetFunc("EVP_DecryptUpdate");
auto ret =
EVP_DecryptUpdate(ctx, plain_data, plain_len, cipher_data.data(), static_cast<int32_t>(cipher_data.size()));
if (ret != 1) {
MS_LOG(ERROR) << "EVP_DecryptUpdate failed";
return false;
}
int (*EVP_CIPHER_CTX_ctrl)(EVP_CIPHER_CTX *, int, int, void *) =
(int (*)(EVP_CIPHER_CTX *, int, int, void *))loader.GetFunc("EVP_CIPHER_CTX_ctrl");
if (!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_TAG, Byte16, tag)) {
MS_LOG(ERROR) << "EVP_CIPHER_CTX_ctrl failed";
return false;
}
int32_t mlen = 0;
int (*EVP_DecryptFinal_ex)(EVP_CIPHER_CTX *, unsigned char *, int *) =
(int (*)(EVP_CIPHER_CTX *, unsigned char *, int *))loader.GetFunc("EVP_DecryptFinal_ex");
ret = EVP_DecryptFinal_ex(ctx, plain_data + *plain_len, &mlen);
if (ret != 1) {
MS_LOG(ERROR) << "EVP_DecryptFinal_ex failed";
return false;
}
*plain_len += mlen;
void (*EVP_CIPHER_CTX_free)(EVP_CIPHER_CTX *) = (void (*)(EVP_CIPHER_CTX *))loader.GetFunc("EVP_CIPHER_CTX_free");
EVP_CIPHER_CTX_free(ctx);
iv.assign(iv.size(), 0);
return true;
}
std::unique_ptr<Byte[]> Decrypt(const std::string &lib_path, size_t *decrypt_len, const Byte *model_data,
const size_t data_size, const Byte *key, const size_t key_len,
const std::string &dec_mode) {
if (model_data == nullptr) {
MS_LOG(ERROR) << "model_data is nullptr.";
return nullptr;
}
if (key == nullptr) {
MS_LOG(ERROR) << "key is nullptr.";
return nullptr;
}
if (decrypt_len == nullptr) {
MS_LOG(ERROR) << "decrypt_len is nullptr.";
return nullptr;
}
auto ret = loader.Open(lib_path);
if (ret != RET_OK) {
MS_LOG(ERROR) << "loader open failed.";
return nullptr;
}
std::vector<char> block_buf;
std::vector<char> int_buf(sizeof(int32_t));
std::vector<Byte> decrypt_block_buf(MAX_BLOCK_SIZE);
auto decrypt_data = std::make_unique<Byte[]>(data_size);
int32_t decrypt_block_len;
size_t offset = 0;
*decrypt_len = 0;
if (dec_mode != "AES-GCM") {
MS_LOG(ERROR) << "dec_mode only support AES-GCM.";
return nullptr;
}
if (key_len != Byte16) {
MS_LOG(ERROR) << "key_len only support 16.";
return nullptr;
}
while (offset < data_size) {
if (offset + sizeof(int32_t) > data_size) {
MS_LOG(ERROR) << "assign len is invalid.";
return nullptr;
}
int_buf.assign(model_data + offset, model_data + offset + sizeof(int32_t));
offset += int_buf.size();
auto cipher_flag = ByteToInt(reinterpret_cast<Byte *>(int_buf.data()), int_buf.size());
if (static_cast<unsigned int>(cipher_flag) != MAGIC_NUM) {
MS_LOG(ERROR) << "model_data is not encrypted and therefore cannot be decrypted.";
return nullptr;
}
unsigned char tag[Byte16];
if (offset + Byte16 > data_size) {
MS_LOG(ERROR) << "buffer is invalid.";
return nullptr;
}
memcpy(tag, model_data + offset, Byte16);
offset += Byte16;
if (offset + sizeof(int32_t) > data_size) {
MS_LOG(ERROR) << "assign len is invalid.";
return nullptr;
}
int_buf.assign(model_data + offset, model_data + offset + sizeof(int32_t));
offset += int_buf.size();
auto block_size = ByteToInt(reinterpret_cast<Byte *>(int_buf.data()), int_buf.size());
if (block_size <= 0) {
MS_LOG(ERROR) << "The block_size read from the cipher data must be not negative, but got " << block_size;
return nullptr;
}
if (offset + block_size > data_size) {
MS_LOG(ERROR) << "assign len is invalid.";
return nullptr;
}
block_buf.assign(model_data + offset, model_data + offset + block_size);
offset += block_buf.size();
if (!(BlockDecrypt(decrypt_block_buf.data(), &decrypt_block_len, reinterpret_cast<Byte *>(block_buf.data()),
block_buf.size(), key, static_cast<int32_t>(key_len), dec_mode, tag))) {
MS_LOG(ERROR) << "Failed to decrypt data, please check if dec_key or dec_mode is valid";
return nullptr;
}
memcpy(decrypt_data.get() + *decrypt_len, decrypt_block_buf.data(), static_cast<size_t>(decrypt_block_len));
*decrypt_len += static_cast<size_t>(decrypt_block_len);
}
ret = loader.Close();
if (ret != RET_OK) {
MS_LOG(ERROR) << "loader close failed.";
return nullptr;
}
return decrypt_data;
}
#endif
} // namespace mindspore::lite

View File

@ -0,0 +1,29 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_UTILS_DECRYPT_H_
#define MINDSPORE_CORE_UTILS_DECRYPT_H_
#include <string>
#include <memory>
typedef unsigned char Byte;
namespace mindspore::lite {
std::unique_ptr<Byte[]> Decrypt(const std::string &lib_path, size_t *decrypt_len, const Byte *model_data,
const size_t data_size, const Byte *key, const size_t key_len,
const std::string &dec_mode);
} // namespace mindspore::lite
#endif

View File

@ -30,10 +30,13 @@ namespace mindspore {
namespace lite { namespace lite {
int DynamicLibraryLoader::Open(const std::string &lib_path) { int DynamicLibraryLoader::Open(const std::string &lib_path) {
if (handler_ != nullptr) { if (handler_ != nullptr) {
return RET_ERROR; return RET_OK;
} }
std::string real_path = RealPath(lib_path.c_str()); std::string real_path = RealPath(lib_path.c_str());
if (real_path.empty()) {
MS_LOG(ERROR) << "real_path is invalid.";
return RET_ERROR;
}
#ifndef _WIN32 #ifndef _WIN32
#ifndef ENABLE_ARM #ifndef ENABLE_ARM
handler_ = dlopen(real_path.c_str(), RTLD_LAZY | RTLD_DEEPBIND); handler_ = dlopen(real_path.c_str(), RTLD_LAZY | RTLD_DEEPBIND);

View File

@ -0,0 +1,99 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/storage.h"
#include <sys/stat.h>
#ifndef _MSC_VER
#include <unistd.h>
#endif
#include "flatbuffers/flatbuffers.h"
#include "src/common/log_adapter.h"
#include "src/common/file_utils.h"
namespace mindspore {
namespace lite {
namespace {
constexpr size_t kMaxNum1024 = 1024;
}
int Storage::Save(const schema::MetaGraphT &graph, const std::string &outputPath) {
flatbuffers::FlatBufferBuilder builder(kMaxNum1024);
auto offset = schema::MetaGraph::Pack(builder, &graph);
builder.Finish(offset);
schema::FinishMetaGraphBuffer(builder, offset);
int size = builder.GetSize();
auto content = builder.GetBufferPointer();
if (content == nullptr) {
MS_LOG(ERROR) << "GetBufferPointer nullptr";
return RET_ERROR;
}
std::string filename = outputPath;
if (filename.length() == 0) {
MS_LOG(ERROR) << "Invalid output path.";
return RET_ERROR;
}
if (filename.substr(filename.find_last_of(".") + 1) != "ms") {
filename = filename + ".ms";
}
#ifndef _MSC_VER
if (access(filename.c_str(), F_OK) == 0) {
chmod(filename.c_str(), S_IWUSR);
}
#endif
std::ofstream output(filename, std::ofstream::binary);
if (!output.is_open()) {
MS_LOG(ERROR) << "Can not open output file: " << filename;
return RET_ERROR;
}
output.write((const char *)content, size);
if (output.bad()) {
output.close();
MS_LOG(ERROR) << "Write output file : " << filename << " failed";
return RET_ERROR;
}
output.close();
#ifndef _MSC_VER
chmod(filename.c_str(), S_IRUSR);
#endif
return RET_OK;
}
schema::MetaGraphT *Storage::Load(const std::string &inputPath) {
size_t size = 0;
std::string filename = inputPath;
if (filename.length() == 0) {
MS_LOG(ERROR) << "Invalid input path.";
return nullptr;
}
if (filename.substr(filename.find_last_of(".") + 1) != "ms") {
filename = filename + ".ms";
}
auto buf = ReadFile(filename.c_str(), &size);
if (buf == nullptr) {
MS_LOG(ERROR) << "the file buffer is nullptr";
return nullptr;
}
flatbuffers::Verifier verify((const uint8_t *)buf, size);
if (!schema::VerifyMetaGraphBuffer(verify)) {
MS_LOG(ERROR) << "the buffer is invalid and fail to create meta graph";
return nullptr;
}
auto graphDefT = schema::UnPackMetaGraph(buf);
return graphDefT.release();
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,37 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_COMMON_STORAGE_H
#define MINDSPORE_LITE_SRC_COMMON_STORAGE_H
#include <fstream>
#include <string>
#include "include/errorcode.h"
#include "flatbuffers/flatbuffers.h"
#include "schema/inner/model_generated.h"
namespace mindspore {
namespace lite {
class Storage {
public:
static int Save(const schema::MetaGraphT &graph, const std::string &outputPath);
static schema::MetaGraphT *Load(const std::string &inputPath);
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_COMMON_STORAGE_H

View File

@ -30,16 +30,74 @@
#include "src/cxx_api/callback/callback_adapter.h" #include "src/cxx_api/callback/callback_adapter.h"
#include "src/cxx_api/callback/callback_impl.h" #include "src/cxx_api/callback/callback_impl.h"
#include "src/cxx_api/model/model_impl.h" #include "src/cxx_api/model/model_impl.h"
#ifdef ENABLE_OPENSSL
#include "src/common/decrypt.h"
#include "src/common/file_utils.h"
#endif
namespace mindspore { namespace mindspore {
std::mutex g_impl_init_lock; std::mutex g_impl_init_lock;
#ifdef ENABLE_OPENSSL
Status DecryptModel(const std::string &cropto_lib_path, const void *model_buf, size_t model_size, const Key &dec_key,
const std::string &dec_mode, std::unique_ptr<Byte[]> *decrypt_buffer, size_t *decrypt_len) {
if (model_buf == nullptr) {
MS_LOG(ERROR) << "model_buf is nullptr.";
return kLiteError;
}
*decrypt_len = 0;
*decrypt_buffer = lite::Decrypt(cropto_lib_path, decrypt_len, reinterpret_cast<const Byte *>(model_buf), model_size,
dec_key.key, dec_key.len, dec_mode);
if (*decrypt_buffer == nullptr || *decrypt_len == 0) {
MS_LOG(ERROR) << "Decrypt buffer failed";
return kLiteError;
}
return kSuccess;
}
#endif
Status Model::Build(const void *model_data, size_t data_size, ModelType model_type, Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode,
const std::vector<char> &dec_mode) { const std::string &cropto_lib_path) {
#ifdef ENABLE_OPENSSL
if (impl_ == nullptr) { if (impl_ == nullptr) {
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock); std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl()); impl_ = std::make_shared<ModelImpl>();
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Model implement is null.";
return kLiteFileError;
}
}
if (dec_key.len > 0) {
std::unique_ptr<Byte[]> decrypt_buffer;
size_t decrypt_len = 0;
Status ret = DecryptModel(cropto_lib_path, model_data, data_size, dec_key, dec_mode, &decrypt_buffer, &decrypt_len);
if (ret != kSuccess) {
MS_LOG(ERROR) << "Decrypt model failed.";
return ret;
}
ret = impl_->Build(decrypt_buffer.get(), decrypt_len, model_type, model_context);
if (ret != kSuccess) {
MS_LOG(ERROR) << "Build model failed.";
return ret;
}
} else {
Status ret = impl_->Build(model_data, data_size, model_type, model_context);
if (ret != kSuccess) {
return ret;
}
}
return kSuccess;
#else
MS_LOG(ERROR) << "The lib is not support Decrypt Model.";
return kLiteError;
#endif
}
Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context) {
if (impl_ == nullptr) {
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
impl_ = std::make_shared<ModelImpl>();
if (impl_ == nullptr) { if (impl_ == nullptr) {
MS_LOG(ERROR) << "Model implement is null."; MS_LOG(ERROR) << "Model implement is null.";
return kLiteFileError; return kLiteFileError;
@ -54,11 +112,59 @@ Status Model::Build(const void *model_data, size_t data_size, ModelType model_ty
} }
Status Model::Build(const std::vector<char> &model_path, ModelType model_type, Status Model::Build(const std::vector<char> &model_path, ModelType model_type,
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode,
const std::vector<char> &dec_mode) { const std::vector<char> &cropto_lib_path) {
#ifdef ENABLE_OPENSSL
if (impl_ == nullptr) { if (impl_ == nullptr) {
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock); std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl()); impl_ = std::make_shared<ModelImpl>();
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Model implement is null.";
return kLiteFileError;
}
}
if (dec_key.len > 0) {
size_t model_size;
auto model_buf = lite::ReadFile(model_path.data(), &model_size);
if (model_buf == nullptr) {
MS_LOG(ERROR) << "Read model file failed";
return kLiteError;
}
std::unique_ptr<Byte[]> decrypt_buffer;
size_t decrypt_len = 0;
Status ret = DecryptModel(CharToString(cropto_lib_path), model_buf, model_size, dec_key, dec_mode, &decrypt_buffer,
&decrypt_len);
if (ret != kSuccess) {
MS_LOG(ERROR) << "Decrypt model failed.";
delete[] model_buf;
return ret;
}
ret = impl_->Build(decrypt_buffer.get(), decrypt_len, model_type, model_context);
if (ret != kSuccess) {
MS_LOG(ERROR) << "Build model failed.";
delete[] model_buf;
return ret;
}
delete[] model_buf;
} else {
Status ret = impl_->Build(CharToString(model_path), model_type, model_context);
if (ret != kSuccess) {
MS_LOG(ERROR) << "Build model failed.";
return ret;
}
}
return kSuccess;
#else
MS_LOG(ERROR) << "The lib is not support Decrypt Model.";
return kLiteError;
#endif
}
Status Model::Build(const std::vector<char> &model_path, ModelType model_type,
const std::shared_ptr<Context> &model_context) {
if (impl_ == nullptr) {
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
impl_ = std::make_shared<ModelImpl>();
if (impl_ == nullptr) { if (impl_ == nullptr) {
MS_LOG(ERROR) << "Model implement is null."; MS_LOG(ERROR) << "Model implement is null.";
return kLiteFileError; return kLiteFileError;
@ -77,7 +183,7 @@ Status Model::Build(GraphCell graph, const std::shared_ptr<Context> &model_conte
std::stringstream err_msg; std::stringstream err_msg;
if (impl_ == nullptr) { if (impl_ == nullptr) {
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock); std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl()); impl_ = std::make_shared<ModelImpl>();
if (impl_ == nullptr) { if (impl_ == nullptr) {
MS_LOG(ERROR) << "Model implement is null."; MS_LOG(ERROR) << "Model implement is null.";
return kLiteFileError; return kLiteFileError;
@ -258,7 +364,7 @@ Status Model::LoadConfig(const std::vector<char> &config_path) {
return Status(kLiteFileError, "Illegal operation."); return Status(kLiteFileError, "Illegal operation.");
} }
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl()); impl_ = std::make_shared<ModelImpl>();
if (impl_ == nullptr) { if (impl_ == nullptr) {
MS_LOG(ERROR) << "Model implement is null."; MS_LOG(ERROR) << "Model implement is null.";
return Status(kLiteFileError, "Fail to load config file."); return Status(kLiteFileError, "Fail to load config file.");
@ -276,7 +382,7 @@ Status Model::UpdateConfig(const std::vector<char> &section,
const std::pair<std::vector<char>, std::vector<char>> &config) { const std::pair<std::vector<char>, std::vector<char>> &config) {
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock); std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
if (impl_ == nullptr) { if (impl_ == nullptr) {
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl()); impl_ = std::make_shared<ModelImpl>();
} }
if (impl_ != nullptr) { if (impl_ != nullptr) {
return impl_->UpdateConfig(CharToString(section), {CharToString(config.first), CharToString(config.second)}); return impl_->UpdateConfig(CharToString(section), {CharToString(config.first), CharToString(config.second)});
@ -388,5 +494,4 @@ float Model::GetLearningRate() {
} }
return impl_->GetLearningRate(); return impl_->GetLearningRate();
} }
} // namespace mindspore } // namespace mindspore

View File

@ -26,7 +26,7 @@
#include "schema/inner/model_generated.h" #include "schema/inner/model_generated.h"
#include "src/train/train_utils.h" #include "src/train/train_utils.h"
#include "src/common/quant_utils.h" #include "src/common/quant_utils.h"
#include "tools/common/meta_graph_serializer.h" #include "src/common/storage.h"
#include "src/train/graph_fusion.h" #include "src/train/graph_fusion.h"
#include "src/train/graph_dropout.h" #include "src/train/graph_dropout.h"
#include "src/weight_decoder.h" #include "src/weight_decoder.h"
@ -553,7 +553,7 @@ int TrainExport::ExportInit(const std::string model_name, std::string version) {
return RET_OK; return RET_OK;
} }
int TrainExport::SaveToFile() { return MetaGraphSerializer::Save(*meta_graph_, file_name_); } int TrainExport::SaveToFile() { return Storage::Save(*meta_graph_, file_name_); }
bool TrainExport::IsInputTensor(const schema::TensorT &t) { bool TrainExport::IsInputTensor(const schema::TensorT &t) {
int total_dims = std::accumulate(t.dims.begin(), t.dims.end(), 1, std::multiplies<int>()); int total_dims = std::accumulate(t.dims.begin(), t.dims.end(), 1, std::multiplies<int>());

View File

@ -1 +1 @@
844020 848116

View File

@ -69,6 +69,7 @@ constexpr int kNumPrintMin = 5;
constexpr const char *DELIM_COLON = ":"; constexpr const char *DELIM_COLON = ":";
constexpr const char *DELIM_COMMA = ","; constexpr const char *DELIM_COMMA = ",";
constexpr const char *DELIM_SLASH = "/"; constexpr const char *DELIM_SLASH = "/";
constexpr size_t kEncMaxLen = 16;
extern const std::unordered_map<int, std::string> kTypeIdMap; extern const std::unordered_map<int, std::string> kTypeIdMap;
extern const std::unordered_map<mindspore::Format, std::string> kTensorFormatMap; extern const std::unordered_map<mindspore::Format, std::string> kTensorFormatMap;
@ -139,6 +140,11 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
AddFlag(&BenchmarkFlags::cosine_distance_threshold_, "cosineDistanceThreshold", "cosine distance threshold", -1.1); AddFlag(&BenchmarkFlags::cosine_distance_threshold_, "cosineDistanceThreshold", "cosine distance threshold", -1.1);
AddFlag(&BenchmarkFlags::resize_dims_in_, "inputShapes", AddFlag(&BenchmarkFlags::resize_dims_in_, "inputShapes",
"Shape of input data, the format should be NHWC. e.g. 1,32,32,32:1,1,32,32,1", ""); "Shape of input data, the format should be NHWC. e.g. 1,32,32,32:1,1,32,32,1", "");
AddFlag(&BenchmarkFlags::decrypt_key_str_, "decryptKey",
"The key used to decrypt the file, expressed in hexadecimal characters. Only support AES-GCM and the key "
"length is 16.",
"");
AddFlag(&BenchmarkFlags::crypto_lib_path_, "cryptoLibPath", "Pass the crypto library path.", "");
AddFlag(&BenchmarkFlags::enable_parallel_predict_, "enableParallelPredict", "Enable model parallel : true | false", AddFlag(&BenchmarkFlags::enable_parallel_predict_, "enableParallelPredict", "Enable model parallel : true | false",
false); false);
AddFlag(&BenchmarkFlags::parallel_request_num_, "parallelRequestNum", "parallel request num of parallel predict", AddFlag(&BenchmarkFlags::parallel_request_num_, "parallelRequestNum", "parallel request num of parallel predict",
@ -192,6 +198,9 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
std::string perf_event_ = "CYCLE"; std::string perf_event_ = "CYCLE";
bool dump_tensor_data_ = false; bool dump_tensor_data_ = false;
bool print_tensor_data_ = false; bool print_tensor_data_ = false;
std::string decrypt_key_str_;
std::string dec_mode_ = "AES-GCM";
std::string crypto_lib_path_;
}; };
class MS_API BenchmarkBase { class MS_API BenchmarkBase {

View File

@ -698,7 +698,6 @@ int BenchmarkUnifiedApi::CompareDataGetTotalBiasAndSize(const std::string &name,
*total_size += 1; *total_size += 1;
return RET_OK; return RET_OK;
} }
int BenchmarkUnifiedApi::CompareDataGetTotalCosineDistanceAndSize(const std::string &name, mindspore::MSTensor *tensor, int BenchmarkUnifiedApi::CompareDataGetTotalCosineDistanceAndSize(const std::string &name, mindspore::MSTensor *tensor,
float *total_cosine_distance, int *total_size) { float *total_cosine_distance, int *total_size) {
if (tensor == nullptr) { if (tensor == nullptr) {
@ -1044,6 +1043,33 @@ int BenchmarkUnifiedApi::RunModelPool(std::shared_ptr<mindspore::Context> contex
} }
#endif #endif
int BenchmarkUnifiedApi::CompileGraph(ModelType model_type, const std::shared_ptr<Context> &context,
const std::string &model_name) {
Key dec_key;
if (!flags_->decrypt_key_str_.empty()) {
dec_key.len = lite::Hex2ByteArray(flags_->decrypt_key_str_, dec_key.key, kEncMaxLen);
if (dec_key.len == 0) {
MS_LOG(ERROR) << "dec_key.len == 0";
return RET_INPUT_PARAM_INVALID;
}
flags_->decrypt_key_str_.clear();
}
Status ret;
if (flags_->crypto_lib_path_.empty()) {
ret = ms_model_.Build(flags_->model_file_, model_type, context);
} else {
ret =
ms_model_.Build(flags_->model_file_, model_type, context, dec_key, flags_->dec_mode_, flags_->crypto_lib_path_);
}
memset(dec_key.key, 0, kEncMaxLen);
if (ret != kSuccess) {
MS_LOG(ERROR) << "ms_model_.Build failed while running ", model_name.c_str();
std::cout << "ms_model_.Build failed while running ", model_name.c_str();
return RET_ERROR;
}
return RET_OK;
}
int BenchmarkUnifiedApi::RunBenchmark() { int BenchmarkUnifiedApi::RunBenchmark() {
auto start_prepare_time = GetTimeUs(); auto start_prepare_time = GetTimeUs();
@ -1098,19 +1124,17 @@ int BenchmarkUnifiedApi::RunBenchmark() {
} }
#endif #endif
auto ret = ms_model_.Build(flags_->model_file_, model_type, context); status = CompileGraph(model_type, context, model_name);
if (ret != kSuccess) { if (status != RET_OK) {
MS_LOG(ERROR) << "ms_model_.Build failed while running ", model_name.c_str(); MS_LOG(ERROR) << "Compile graph failed.";
std::cout << "ms_model_.Build failed while running ", model_name.c_str(); return status;
return RET_ERROR;
} }
if (!flags_->resize_dims_.empty()) { if (!flags_->resize_dims_.empty()) {
std::vector<std::vector<int64_t>> resize_dims; std::vector<std::vector<int64_t>> resize_dims;
(void)std::transform(flags_->resize_dims_.begin(), flags_->resize_dims_.end(), std::back_inserter(resize_dims), (void)std::transform(flags_->resize_dims_.begin(), flags_->resize_dims_.end(), std::back_inserter(resize_dims),
[&](auto &shapes) { return this->ConverterToInt64Vector<int>(shapes); }); [&](auto &shapes) { return this->ConverterToInt64Vector<int>(shapes); });
ret = ms_model_.Resize(ms_model_.GetInputs(), resize_dims); auto ret = ms_model_.Resize(ms_model_.GetInputs(), resize_dims);
if (ret != kSuccess) { if (ret != kSuccess) {
MS_LOG(ERROR) << "Input tensor resize failed."; MS_LOG(ERROR) << "Input tensor resize failed.";
std::cout << "Input tensor resize failed."; std::cout << "Input tensor resize failed.";

View File

@ -62,6 +62,8 @@ class MS_API BenchmarkUnifiedApi : public BenchmarkBase {
float *total_cosine_distance, int *total_size); float *total_cosine_distance, int *total_size);
void InitContext(const std::shared_ptr<mindspore::Context> &context); void InitContext(const std::shared_ptr<mindspore::Context> &context);
int CompileGraph(ModelType model_type, const std::shared_ptr<Context> &context, const std::string &model_name);
#ifdef ENABLE_OPENGL_TEXTURE #ifdef ENABLE_OPENGL_TEXTURE
int GenerateGLTexture(std::map<std::string, GLuint> *inputGlTexture); int GenerateGLTexture(std::map<std::string, GLuint> *inputGlTexture);

View File

@ -206,7 +206,8 @@ bool MetaGraphSerializer::ExtraAndSerializeModelWeight(const schema::MetaGraphT
return true; return true;
} }
bool MetaGraphSerializer::SerializeModelAndUpdateWeight(const schema::MetaGraphT &meta_graphT) { bool MetaGraphSerializer::SerializeModelAndUpdateWeight(const schema::MetaGraphT &meta_graphT, const Byte *key,
const size_t key_len, const std::string &enc_mode) {
// serialize model // serialize model
flatbuffers::FlatBufferBuilder builder(kFlatbuffersBuilderInitSize); flatbuffers::FlatBufferBuilder builder(kFlatbuffersBuilderInitSize);
auto offset = schema::MetaGraph::Pack(builder, &meta_graphT); auto offset = schema::MetaGraph::Pack(builder, &meta_graphT);
@ -214,7 +215,7 @@ bool MetaGraphSerializer::SerializeModelAndUpdateWeight(const schema::MetaGraphT
schema::FinishMetaGraphBuffer(builder, offset); schema::FinishMetaGraphBuffer(builder, offset);
size_t size = builder.GetSize(); size_t size = builder.GetSize();
auto content = builder.GetBufferPointer(); auto content = builder.GetBufferPointer();
if (!SerializeModel(content, size)) { if (!SerializeModel(content, size, key, key_len, enc_mode)) {
MS_LOG(ERROR) << "Serialize graph failed"; MS_LOG(ERROR) << "Serialize graph failed";
return false; return false;
} }
@ -238,7 +239,8 @@ bool MetaGraphSerializer::SerializeModelAndUpdateWeight(const schema::MetaGraphT
return true; return true;
} }
int MetaGraphSerializer::Save(const schema::MetaGraphT &graph, const std::string &output_path) { int MetaGraphSerializer::Save(const schema::MetaGraphT &graph, const std::string &output_path, const Byte *key,
const size_t key_len, const std::string &enc_mode) {
flatbuffers::FlatBufferBuilder builder(kFlatbuffersBuilderInitSize); flatbuffers::FlatBufferBuilder builder(kFlatbuffersBuilderInitSize);
auto offset = schema::MetaGraph::Pack(builder, &graph); auto offset = schema::MetaGraph::Pack(builder, &graph);
builder.Finish(offset); builder.Finish(offset);
@ -255,7 +257,7 @@ int MetaGraphSerializer::Save(const schema::MetaGraphT &graph, const std::string
return RET_ERROR; return RET_ERROR;
} }
if (save_together) { if (save_together) {
if (!meta_graph_serializer.SerializeModel(builder.GetBufferPointer(), size)) { if (!meta_graph_serializer.SerializeModel(builder.GetBufferPointer(), size, key, key_len, enc_mode)) {
MS_LOG(ERROR) << "Serialize graph failed"; MS_LOG(ERROR) << "Serialize graph failed";
return RET_ERROR; return RET_ERROR;
} }
@ -264,7 +266,7 @@ int MetaGraphSerializer::Save(const schema::MetaGraphT &graph, const std::string
MS_LOG(ERROR) << "Serialize graph weight failed"; MS_LOG(ERROR) << "Serialize graph weight failed";
return RET_ERROR; return RET_ERROR;
} }
if (!meta_graph_serializer.SerializeModelAndUpdateWeight(graph)) { if (!meta_graph_serializer.SerializeModelAndUpdateWeight(graph, key, key_len, enc_mode)) {
MS_LOG(ERROR) << "Serialize graph and adjust weight failed"; MS_LOG(ERROR) << "Serialize graph and adjust weight failed";
return RET_ERROR; return RET_ERROR;
} }
@ -283,14 +285,25 @@ MetaGraphSerializer::~MetaGraphSerializer() {
} }
} }
bool MetaGraphSerializer::SerializeModel(const void *content, size_t size) { bool MetaGraphSerializer::SerializeModel(const void *content, size_t size, const Byte *key, const size_t key_len,
const std::string &enc_mode) {
MS_ASSERT(model_fs_ != nullptr); MS_ASSERT(model_fs_ != nullptr);
if (size == 0 || content == nullptr) { if (size == 0 || content == nullptr) {
MS_LOG(ERROR) << "Input meta graph buffer is nullptr"; MS_LOG(ERROR) << "Input meta graph buffer is nullptr";
return false; return false;
} }
if (key_len > 0) {
model_fs_->write((const char *)content, static_cast<int64_t>(size)); size_t encrypt_len;
auto encrypt_content = Encrypt(&encrypt_len, reinterpret_cast<const Byte *>(content), size, key, key_len, enc_mode);
if (encrypt_content == nullptr || encrypt_len == 0) {
MS_LOG(ERROR) << "Encrypt failed.";
model_fs_->close();
return RET_ERROR;
}
model_fs_->write(reinterpret_cast<const char *>(encrypt_content.get()), encrypt_len);
} else {
model_fs_->write((const char *)content, static_cast<int64_t>(size));
}
if (model_fs_->bad()) { if (model_fs_->bad()) {
MS_LOG(ERROR) << "Write model file failed: " << save_model_path_; MS_LOG(ERROR) << "Write model file failed: " << save_model_path_;
return RET_ERROR; return RET_ERROR;

View File

@ -21,12 +21,14 @@
#include <string> #include <string>
#include "flatbuffers/flatbuffers.h" #include "flatbuffers/flatbuffers.h"
#include "schema/inner/model_generated.h" #include "schema/inner/model_generated.h"
#include "utils/crypto.h"
namespace mindspore::lite { namespace mindspore::lite {
class MetaGraphSerializer { class MetaGraphSerializer {
public: public:
// save serialized fb model // save serialized fb model
static int Save(const schema::MetaGraphT &graph, const std::string &output_path); static int Save(const schema::MetaGraphT &graph, const std::string &output_path, const Byte *key = {},
const size_t key_len = 0, const std::string &enc_mode = "");
private: private:
MetaGraphSerializer() = default; MetaGraphSerializer() = default;
@ -41,9 +43,11 @@ class MetaGraphSerializer {
bool ExtraAndSerializeModelWeight(const schema::MetaGraphT &graph); bool ExtraAndSerializeModelWeight(const schema::MetaGraphT &graph);
bool SerializeModelAndUpdateWeight(const schema::MetaGraphT &meta_graphT); bool SerializeModelAndUpdateWeight(const schema::MetaGraphT &meta_graphT, const Byte *key, const size_t key_len,
const std::string &enc_mode);
bool SerializeModel(const void *content, size_t size); bool SerializeModel(const void *content, size_t size, const Byte *key, const size_t key_len,
const std::string &enc_mode);
private: private:
int64_t cur_offset_ = 0; int64_t cur_offset_ = 0;

View File

@ -18,6 +18,7 @@
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
#include <string> #include <string>
#include <regex>
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
@ -126,5 +127,42 @@ bool ConvertDoubleVector(const std::string &str, std::vector<double> *value) {
} }
return true; return true;
} }
size_t Hex2ByteArray(const std::string &hex_str, unsigned char *byte_array, size_t max_len) {
std::regex r("[0-9a-fA-F]+");
if (!std::regex_match(hex_str, r)) {
MS_LOG(ERROR) << "Some characters of dec_key not in [0-9a-fA-F]";
return 0;
}
if (hex_str.size() % 2 == 1) { // Mod 2 determines whether it is odd
MS_LOG(ERROR) << "the hexadecimal dec_key length must be even";
return 0;
}
size_t byte_len = hex_str.size() / 2; // Two hexadecimal characters represent a byte
if (byte_len > max_len) {
MS_LOG(ERROR) << "the hexadecimal dec_key length exceeds the maximum limit: " << max_len;
return 0;
}
constexpr int32_t a_val = 10; // The value of 'A' in hexadecimal is 10
constexpr size_t half_byte_offset = 4;
for (size_t i = 0; i < byte_len; ++i) {
size_t p = i * 2; // The i-th byte is represented by the 2*i and 2*i+1 hexadecimal characters
if (hex_str[p] >= 'a' && hex_str[p] <= 'f') {
byte_array[i] = hex_str[p] - 'a' + a_val;
} else if (hex_str[p] >= 'A' && hex_str[p] <= 'F') {
byte_array[i] = hex_str[p] - 'A' + a_val;
} else {
byte_array[i] = hex_str[p] - '0';
}
if (hex_str[p + 1] >= 'a' && hex_str[p + 1] <= 'f') {
byte_array[i] = (byte_array[i] << half_byte_offset) | (hex_str[p + 1] - 'a' + a_val);
} else if (hex_str[p] >= 'A' && hex_str[p] <= 'F') {
byte_array[i] = (byte_array[i] << half_byte_offset) | (hex_str[p + 1] - 'A' + a_val);
} else {
byte_array[i] = (byte_array[i] << half_byte_offset) | (hex_str[p + 1] - '0');
}
}
return byte_len;
}
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

View File

@ -40,6 +40,8 @@ bool ConvertDoubleNum(const std::string &str, double *value);
bool ConvertBool(std::string str, bool *value); bool ConvertBool(std::string str, bool *value);
bool ConvertDoubleVector(const std::string &str, std::vector<double> *value); bool ConvertDoubleVector(const std::string &str, std::vector<double> *value);
size_t Hex2ByteArray(const std::string &hex_str, unsigned char *byte_array, size_t max_len);
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_LITE_SRC_TOOLS_STRING_UTIL_H_ #endif // MINDSPORE_LITE_SRC_TOOLS_STRING_UTIL_H_

View File

@ -343,4 +343,32 @@ int GenerateRandomData(mindspore::tensor::MSTensor *tensor) {
} }
return RET_OK; return RET_OK;
} }
int GenerateRandomData(mindspore::MSTensor *tensor) {
MS_ASSERT(tensor != nullptr);
auto input_data = tensor->MutableData();
if (input_data == nullptr) {
MS_LOG(ERROR) << "MallocData for inTensor failed";
return RET_ERROR;
}
int status = RET_ERROR;
if (static_cast<TypeId>(tensor->DataType()) == kObjectTypeString) {
MSTensor *input = MSTensor::StringsToTensor(tensor->Name(), {"you're the best."});
if (input == nullptr) {
std::cerr << "StringsToTensor failed" << std::endl;
MS_LOG(ERROR) << "StringsToTensor failed";
return RET_ERROR;
}
*tensor = *input;
delete input;
} else {
status = GenerateRandomData(tensor->DataSize(), input_data, static_cast<int>(tensor->DataType()));
}
if (status != RET_OK) {
std::cerr << "GenerateRandomData for inTensor failed: " << status << std::endl;
MS_LOG(ERROR) << "GenerateRandomData for inTensor failed:" << status;
return status;
}
return RET_OK;
}
} // namespace mindspore::lite } // namespace mindspore::lite

View File

@ -78,6 +78,8 @@ std::unique_ptr<schema::QuantParamT> CopyQuantParamT(const std::unique_ptr<schem
int GenerateRandomData(mindspore::tensor::MSTensor *tensors); int GenerateRandomData(mindspore::tensor::MSTensor *tensors);
int GenerateRandomData(mindspore::MSTensor *tensors);
int GenerateRandomData(size_t size, void *data, int data_type); int GenerateRandomData(size_t size, void *data, int data_type);
template <typename T, typename Distribution> template <typename T, typename Distribution>

View File

@ -1,7 +1,9 @@
add_definitions(-DPRIMITIVE_WRITEABLE) add_definitions(-DPRIMITIVE_WRITEABLE)
add_definitions(-DUSE_GLOG) add_definitions(-DUSE_GLOG)
set(USE_GLOG on) set(USE_GLOG on)
if(MSLITE_ENABLE_MODEL_ENCRYPTION)
add_compile_definitions(ENABLE_OPENSSL)
endif()
set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src) set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src)
set(CCSRC_SRC set(CCSRC_SRC
@ -13,8 +15,8 @@ set(CCSRC_SRC
include_directories(${TOP_DIR}/mindspore/ccsrc/plugin/device/cpu/kernel) include_directories(${TOP_DIR}/mindspore/ccsrc/plugin/device/cpu/kernel)
if(NOT WIN32 AND NOT MSLITE_ENABLE_ACL) if(NOT WIN32 AND NOT MSLITE_ENABLE_ACL)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -rdynamic -fvisibility=hidden") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -rdynamic -fvisibility=hidden")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -rdynamic -fvisibility=hidden") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -rdynamic -fvisibility=hidden")
endif() endif()
file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
@ -50,11 +52,10 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/import/mindir_adjust.cc ${CMAKE_CURRENT_SOURCE_DIR}/import/mindir_adjust.cc
${CMAKE_CURRENT_SOURCE_DIR}/import/mindir_control_flow_adjust.cc ${CMAKE_CURRENT_SOURCE_DIR}/import/mindir_control_flow_adjust.cc
${CMAKE_CURRENT_SOURCE_DIR}/adapter/acl/acl_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/adapter/acl/acl_pass.cc
${SRC_DIR}/common/quant_utils.cc ${SRC_DIR}/common/quant_utils.cc
${SRC_DIR}/common/dynamic_library_loader.cc ${SRC_DIR}/common/dynamic_library_loader.cc
${SRC_DIR}/train/train_populate_parameter.cc ${SRC_DIR}/train/train_populate_parameter.cc
${SRC_DIR}/common/config_file.cc
../optimizer/*.cc ../optimizer/*.cc
) )
@ -76,16 +77,20 @@ add_subdirectory(micro/coder)
if(MSLITE_ENABLE_ACL) if(MSLITE_ENABLE_ACL)
set(MODE_ASCEND_ACL ON) set(MODE_ASCEND_ACL ON)
include_directories(${TOP_DIR}/graphengine/inc/external)
include(${TOP_DIR}/cmake/dependency_graphengine.cmake) include(${TOP_DIR}/cmake/dependency_graphengine.cmake)
add_subdirectory(adapter/acl) add_subdirectory(adapter/acl)
link_directories(${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) link_directories(${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
endif() endif()
set(API_SRC ${SRC_DIR}/cxx_api/context.cc) file(GLOB CXX_API_SRCS
if(MSLITE_ENABLE_ACL) ${SRC_DIR}/cxx_api/*.cc
list(APPEND API_SRC ${SRC_DIR}/cxx_api/kernel.cc) ${SRC_DIR}/cxx_api/model/*.cc
endif() ${SRC_DIR}/cxx_api/graph/*.cc
${SRC_DIR}/cxx_api/tensor/*.cc)
set(LITE_SRC ${API_SRC} set(LITE_SRC ${API_SRC}
${CXX_API_SRCS}
${SRC_DIR}/ops/ops_def.cc ${SRC_DIR}/ops/ops_def.cc
${SRC_DIR}/ops/ops_utils.cc ${SRC_DIR}/ops/ops_utils.cc
${SRC_DIR}/common/utils.cc ${SRC_DIR}/common/utils.cc
@ -97,6 +102,7 @@ set(LITE_SRC ${API_SRC}
${SRC_DIR}/common/log.cc ${SRC_DIR}/common/log.cc
${SRC_DIR}/common/prim_util.cc ${SRC_DIR}/common/prim_util.cc
${SRC_DIR}/common/tensor_util.cc ${SRC_DIR}/common/tensor_util.cc
${SRC_DIR}/common/decrypt.cc
${SRC_DIR}/runtime/allocator.cc ${SRC_DIR}/runtime/allocator.cc
${SRC_DIR}/runtime/inner_allocator.cc ${SRC_DIR}/runtime/inner_allocator.cc
${SRC_DIR}/runtime/runtime_allocator.cc ${SRC_DIR}/runtime/runtime_allocator.cc

View File

@ -33,9 +33,15 @@
#include "tools/converter/import/mindspore_importer.h" #include "tools/converter/import/mindspore_importer.h"
#include "nnacl/op_base.h" #include "nnacl/op_base.h"
#include "tools/converter/micro/coder/coder.h" #include "tools/converter/micro/coder/coder.h"
#include "src/common/prim_util.h"
#include "src/common/version_manager.h"
#include "tools/common/tensor_util.h"
#include "include/api/model.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace { namespace {
constexpr size_t kMaxNum1024 = 1024;
void InitConverterParameters(const converter::Flags &flag, converter::ConverterParameters *converter_parameters) { void InitConverterParameters(const converter::Flags &flag, converter::ConverterParameters *converter_parameters) {
MS_ASSERT(converter_parameters != nullptr); MS_ASSERT(converter_parameters != nullptr);
converter_parameters->fmk = flag.fmk; converter_parameters->fmk = flag.fmk;
@ -178,6 +184,90 @@ schema::MetaGraphT *Converter::TransferFuncGraph(const std::unique_ptr<converter
return meta_graph; return meta_graph;
} }
int CheckExistCustomOps(const schema::MetaGraphT *meta_graph, bool *exist_custom_nodes) {
MS_CHECK_TRUE_MSG(meta_graph != nullptr && exist_custom_nodes != nullptr, RET_ERROR, "input params contain nullptr.");
flatbuffers::FlatBufferBuilder fbb(kMaxNum1024);
for (const auto &node : meta_graph->nodes) {
auto prim = ConvertToPrimitive(node->primitive.get(), &fbb);
if (prim == nullptr) {
MS_LOG(ERROR) << "get primitive failed.";
fbb.Clear();
return RET_ERROR;
}
if (IsCustomNode(prim, static_cast<int>(SCHEMA_CUR))) {
*exist_custom_nodes = true;
break;
}
}
fbb.Clear();
return RET_OK;
}
int PreInference(const schema::MetaGraphT &meta_graph, const std::unique_ptr<converter::Flags> &flags) {
if (flags->trainModel) {
MS_LOG(WARNING) << "train model dont support pre-infer.";
return RET_OK;
}
bool exist_custom_nodes = false;
auto check_ret = CheckExistCustomOps(&meta_graph, &exist_custom_nodes);
if (check_ret == RET_ERROR) {
MS_LOG(ERROR) << "CheckExistCustomOps failed.";
return RET_ERROR;
}
if (exist_custom_nodes) {
MS_LOG(WARNING) << "exist custom nodes and will not be pre-infer.";
return RET_OK;
}
mindspore::Model model;
flatbuffers::FlatBufferBuilder builder(kMaxNum1024);
auto offset = schema::MetaGraph::Pack(builder, &meta_graph);
builder.Finish(offset);
schema::FinishMetaGraphBuffer(builder, offset);
int size = builder.GetSize();
auto content = builder.GetBufferPointer();
if (content == nullptr) {
MS_LOG(ERROR) << "GetBufferPointer nullptr";
return RET_ERROR;
}
auto context = std::make_shared<mindspore::Context>();
if (context == nullptr) {
MS_LOG(ERROR) << "New context failed while running ";
std::cerr << "New context failed while running " << std::endl;
return RET_ERROR;
}
std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>();
auto &device_list = context->MutableDeviceInfo();
device_list.push_back(device_info);
auto ret = model.Build(content, size, kMindIR, context);
if (ret != kSuccess) {
MS_LOG(ERROR) << "Build error ";
std::cerr << "Build error " << std::endl;
return RET_ERROR;
}
for (auto &tensor : model.GetInputs()) {
if (tensor.Shape().empty() || tensor.DataSize() <= 0 ||
std::find(tensor.Shape().begin(), tensor.Shape().end(), -1) != tensor.Shape().end()) {
MS_LOG(WARNING) << tensor.Name() << " is dynamic shape and will not be pre-infer.";
return RET_OK;
}
auto status = GenerateRandomData(&tensor);
if (status != RET_OK) {
MS_LOG(ERROR) << tensor.Name() << "GenerateRandomData failed.";
return status;
}
}
std::vector<MSTensor> outputs;
ret = model.Predict(model.GetInputs(), &outputs);
if (ret != kSuccess) {
MS_LOG(ERROR) << "Inference error ";
std::cerr << "Inference error " << std::endl;
return RET_ERROR;
}
return RET_OK;
}
int RunConverter(int argc, const char **argv) { int RunConverter(int argc, const char **argv) {
std::ostringstream oss; std::ostringstream oss;
auto flags = std::make_unique<converter::Flags>(); auto flags = std::make_unique<converter::Flags>();
@ -215,6 +305,18 @@ int RunConverter(int argc, const char **argv) {
// save graph to file // save graph to file
meta_graph->version = Version(); meta_graph->version = Version();
if (flags->infer) {
status = PreInference(*meta_graph, flags);
if (status != RET_OK) {
oss.clear();
oss << "PRE INFERENCE FAILED:" << status << " " << GetErrorInfo(status);
MS_LOG(ERROR) << oss.str();
std::cout << oss.str() << std::endl;
delete meta_graph;
return status;
}
}
if (flags->microParam.enable_micro) { if (flags->microParam.enable_micro) {
status = micro::Coder::MicroSourceCodeGeneration(*meta_graph, flags->outputFile, flags->microParam.codegen_mode, status = micro::Coder::MicroSourceCodeGeneration(*meta_graph, flags->outputFile, flags->microParam.codegen_mode,
flags->microParam.target, flags->microParam.support_parallel, flags->microParam.target, flags->microParam.support_parallel,
@ -228,7 +330,7 @@ int RunConverter(int argc, const char **argv) {
return status; return status;
} }
} else { } else {
status = MetaGraphSerializer::Save(*meta_graph, flags->outputFile); status = MetaGraphSerializer::Save(*meta_graph, flags->outputFile, flags->encKey, flags->keyLen, flags->encMode);
if (status != RET_OK) { if (status != RET_OK) {
delete meta_graph; delete meta_graph;
oss.clear(); oss.clear();
@ -238,7 +340,12 @@ int RunConverter(int argc, const char **argv) {
return status; return status;
} }
} }
// clear key
status = memset_s(flags->encKey, converter::kEncMaxLen, 0, converter::kEncMaxLen);
if (status != EOK) {
MS_LOG(ERROR) << "memset failed.";
return RET_ERROR;
}
delete meta_graph; delete meta_graph;
oss.clear(); oss.clear();
oss << "CONVERT RESULT SUCCESS:" << status; oss << "CONVERT RESULT SUCCESS:" << status;

View File

@ -83,6 +83,20 @@ Flags::Flags() {
""); "");
AddFlag(&Flags::graphInputFormatStr, "inputDataFormat", AddFlag(&Flags::graphInputFormatStr, "inputDataFormat",
"Assign the input format of exported model. Only Valid for 4-dimensional input. NHWC | NCHW", "NHWC"); "Assign the input format of exported model. Only Valid for 4-dimensional input. NHWC | NCHW", "NHWC");
#ifdef ENABLE_OPENSSL
AddFlag(&Flags::encryptionStr, "encryption",
"Whether to export the encryption model."
"true | false",
"true");
AddFlag(&Flags::encKeyStr, "encryptKey",
"The key used to encrypt the file, expressed in hexadecimal characters. Only support AES-GCM and the key "
"length is 16.",
"");
#endif
AddFlag(&Flags::inferStr, "infer",
"Whether to do pre-inference after convert."
"true | false",
"false");
} }
int Flags::InitInputOutputDataType() { int Flags::InitInputOutputDataType() {
@ -310,8 +324,56 @@ int Flags::InitConfigFile() {
return RET_OK; return RET_OK;
} }
int Flags::Init(int argc, const char **argv) { int Flags::InitSaveFP16() {
int ret; if (saveFP16Str == "on") {
saveFP16 = true;
} else if (saveFP16Str == "off") {
saveFP16 = false;
} else {
std::cerr << "Init save_fp16 failed." << std::endl;
return RET_INPUT_PARAM_INVALID;
}
return RET_OK;
}
int Flags::InitPreInference() {
if (this->inferStr == "true") {
this->infer = true;
} else if (this->inferStr == "false") {
this->infer = false;
} else {
std::cerr << "INPUT ILLEGAL: infer must be true|false " << std::endl;
return RET_INPUT_PARAM_INVALID;
}
return RET_OK;
}
int Flags::InitEncrypt() {
if (this->encryptionStr == "true") {
this->encryption = true;
} else if (this->encryptionStr == "false") {
this->encryption = false;
} else {
std::cerr << "INPUT ILLEGAL: encryption must be true|false " << std::endl;
return RET_INPUT_PARAM_INVALID;
}
if (this->encryption) {
if (encKeyStr.empty()) {
MS_LOG(ERROR) << "If you don't need to use model encryption, please set --encryption=false.";
return RET_INPUT_PARAM_INVALID;
}
keyLen = lite::Hex2ByteArray(encKeyStr, encKey, kEncMaxLen);
if (keyLen != kEncMaxLen) {
MS_LOG(ERROR) << "enc_key " << encKeyStr << " must expressed in hexadecimal characters "
<< " and only support AES-GCM method and the key length is 16.";
return RET_INPUT_PARAM_INVALID;
}
encKeyStr.clear();
}
return RET_OK;
}
int Flags::PreInit(int argc, const char **argv) {
if (argc == 1) { if (argc == 1) {
std::cout << this->Usage() << std::endl; std::cout << this->Usage() << std::endl;
return lite::RET_SUCCESS_EXIT; return lite::RET_SUCCESS_EXIT;
@ -353,19 +415,23 @@ int Flags::Init(int argc, const char **argv) {
} }
if (!this->configFile.empty()) { if (!this->configFile.empty()) {
ret = InitConfigFile(); auto ret = InitConfigFile();
if (ret != RET_OK) { if (ret != RET_OK) {
std::cerr << "Init config file failed." << std::endl; std::cerr << "Init config file failed." << std::endl;
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;
} }
} }
return RET_OK;
}
if (saveFP16Str == "on") { int Flags::Init(int argc, const char **argv) {
saveFP16 = true; auto ret = PreInit(argc, argv);
} else if (saveFP16Str == "off") { if (ret != RET_OK) {
saveFP16 = false; return ret;
} else { }
std::cerr << "Init save_fp16 failed." << std::endl; ret = InitSaveFP16();
if (ret != RET_OK) {
std::cerr << "Init save fp16 failed." << std::endl;
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;
} }
@ -398,8 +464,25 @@ int Flags::Init(int argc, const char **argv) {
std::cerr << "Init graph input format failed." << std::endl; std::cerr << "Init graph input format failed." << std::endl;
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;
} }
ret = InitEncrypt();
if (ret != RET_OK) {
std::cerr << "Init encrypt failed." << std::endl;
return RET_INPUT_PARAM_INVALID;
}
ret = InitPreInference();
if (ret != RET_OK) {
std::cerr << "Init pre inference failed." << std::endl;
return RET_INPUT_PARAM_INVALID;
}
return RET_OK; return RET_OK;
} }
Flags::~Flags() {
dec_key.clear();
encKeyStr.clear();
memset(encKey, 0, kEncMaxLen);
}
bool CheckOfflineParallelConfig(const std::string &file, ParallelSplitConfig *parallel_split_config) { bool CheckOfflineParallelConfig(const std::string &file, ParallelSplitConfig *parallel_split_config) {
// device: [device0 device1] ---> {cpu, gpu} // device: [device0 device1] ---> {cpu, gpu}

View File

@ -40,6 +40,7 @@ constexpr auto kMaxSplitRatio = 10;
constexpr auto kComputeRate = "computeRate"; constexpr auto kComputeRate = "computeRate";
constexpr auto kSplitDevice0 = "device0"; constexpr auto kSplitDevice0 = "device0";
constexpr auto kSplitDevice1 = "device1"; constexpr auto kSplitDevice1 = "device1";
constexpr size_t kEncMaxLen = 16;
struct ParallelSplitConfig { struct ParallelSplitConfig {
ParallelSplitType parallel_split_type_ = SplitNo; ParallelSplitType parallel_split_type_ = SplitNo;
std::vector<int64_t> parallel_compute_rates_; std::vector<int64_t> parallel_compute_rates_;
@ -50,7 +51,7 @@ class Flags : public virtual mindspore::lite::FlagParser {
public: public:
Flags(); Flags();
~Flags() override = default; ~Flags() override;
int InitInputOutputDataType(); int InitInputOutputDataType();
@ -66,8 +67,16 @@ class Flags : public virtual mindspore::lite::FlagParser {
int InitExtendedIntegrationInfo(const lite::ConfigFileParser &config_file_parser); int InitExtendedIntegrationInfo(const lite::ConfigFileParser &config_file_parser);
int InitEncrypt();
int InitPreInference();
int InitSaveFP16();
int Init(int argc, const char **argv); int Init(int argc, const char **argv);
int PreInit(int argc, const char **argv);
std::string modelFile; std::string modelFile;
std::string outputFile; std::string outputFile;
std::string fmkIn; std::string fmkIn;
@ -91,7 +100,19 @@ class Flags : public virtual mindspore::lite::FlagParser {
std::string graphInputFormatStr; std::string graphInputFormatStr;
std::string device; std::string device;
mindspore::Format graphInputFormat = mindspore::NHWC; mindspore::Format graphInputFormat = mindspore::NHWC;
bool enable_micro = false; std::string encKeyStr;
std::string encMode = "AES-GCM";
std::string inferStr;
#ifdef ENABLE_OPENSSL
std::string encryptionStr = "true";
bool encryption = true;
#else
std::string encryptionStr = "false";
bool encryption = false;
#endif
bool infer = false;
unsigned char encKey[kEncMaxLen];
size_t keyLen = 0;
lite::quant::CommonQuantParam commonQuantParam; lite::quant::CommonQuantParam commonQuantParam;
lite::quant::MixedBitWeightQuantParam mixedBitWeightQuantParam; lite::quant::MixedBitWeightQuantParam mixedBitWeightQuantParam;