forked from mindspore-Ecosystem/mindspore
add encryption to lite
This commit is contained in:
parent
c337e64241
commit
f670a635f0
|
@ -12,17 +12,66 @@ else()
|
||||||
set(OPENSSL_PATCH_ROOT ${CMAKE_SOURCE_DIR}/third_party/patch/openssl)
|
set(OPENSSL_PATCH_ROOT ${CMAKE_SOURCE_DIR}/third_party/patch/openssl)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(${CMAKE_SYSTEM_NAME} MATCHES "Linux" OR APPLE)
|
if(BUILD_LITE)
|
||||||
mindspore_add_pkg(openssl
|
if(PLATFORM_ARM64 AND ANDROID_NDK_TOOLCHAIN_INCLUDED)
|
||||||
VER 1.1.1k
|
set(ANDROID_NDK_ROOT $ENV{ANDROID_NDK})
|
||||||
LIBS ssl crypto
|
set(PATH
|
||||||
URL ${REQ_URL}
|
${ANDROID_NDK_ROOT}/toolchains/llvm/prebuilt/linux-x86_64/bin:
|
||||||
MD5 ${MD5}
|
${ANDROID_NDK_ROOT}/toolchains/arm-linux-androideabi-4.9/prebuilt/linux-x86_64/bin:
|
||||||
CONFIGURE_COMMAND ./config no-zlib no-shared
|
$ENV{PATH})
|
||||||
PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3711.patch
|
mindspore_add_pkg(openssl
|
||||||
PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3712.patch
|
VER 1.1.1k
|
||||||
)
|
LIBS ssl crypto
|
||||||
include_directories(${openssl_INC})
|
URL ${REQ_URL}
|
||||||
add_library(mindspore::ssl ALIAS openssl::ssl)
|
MD5 ${MD5}
|
||||||
add_library(mindspore::crypto ALIAS openssl::crypto)
|
CONFIGURE_COMMAND ./Configure android-arm64 -D__ANDROID_API__=29 no-zlib
|
||||||
|
PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3711.patch
|
||||||
|
PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3712.patch
|
||||||
|
)
|
||||||
|
elseif(PLATFORM_ARM32 AND ANDROID_NDK_TOOLCHAIN_INCLUDED)
|
||||||
|
set(ANDROID_NDK_ROOT $ENV{ANDROID_NDK})
|
||||||
|
set(PATH
|
||||||
|
${ANDROID_NDK_ROOT}/toolchains/llvm/prebuilt/linux-x86_64/bin:
|
||||||
|
${ANDROID_NDK_ROOT}/toolchains/arm-linux-androideabi-4.9/prebuilt/linux-x86_64/bin:
|
||||||
|
$ENV{PATH})
|
||||||
|
mindspore_add_pkg(openssl
|
||||||
|
VER 1.1.1k
|
||||||
|
LIBS ssl crypto
|
||||||
|
URL ${REQ_URL}
|
||||||
|
MD5 ${MD5}
|
||||||
|
CONFIGURE_COMMAND ./Configure android-arm -D__ANDROID_API__=29 no-zlib
|
||||||
|
PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3711.patch
|
||||||
|
PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3712.patch
|
||||||
|
)
|
||||||
|
elseif(${CMAKE_SYSTEM_NAME} MATCHES "Linux" OR APPLE)
|
||||||
|
mindspore_add_pkg(openssl
|
||||||
|
VER 1.1.1k
|
||||||
|
LIBS ssl crypto
|
||||||
|
URL ${REQ_URL}
|
||||||
|
MD5 ${MD5}
|
||||||
|
CONFIGURE_COMMAND ./config no-zlib no-shared
|
||||||
|
PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3711.patch
|
||||||
|
PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3712.patch
|
||||||
|
)
|
||||||
|
else()
|
||||||
|
MESSAGE(FATAL_ERROR "openssl does not support compilation for the current environment.")
|
||||||
|
endif()
|
||||||
|
include_directories(${openssl_INC})
|
||||||
|
add_library(mindspore::ssl ALIAS openssl::ssl)
|
||||||
|
add_library(mindspore::crypto ALIAS openssl::crypto)
|
||||||
|
else()
|
||||||
|
if(${CMAKE_SYSTEM_NAME} MATCHES "Linux" OR APPLE)
|
||||||
|
mindspore_add_pkg(openssl
|
||||||
|
VER 1.1.1k
|
||||||
|
LIBS ssl crypto
|
||||||
|
URL ${REQ_URL}
|
||||||
|
MD5 ${MD5}
|
||||||
|
CONFIGURE_COMMAND ./config no-zlib no-shared
|
||||||
|
PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3711.patch
|
||||||
|
PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3712.patch
|
||||||
|
)
|
||||||
|
include_directories(${openssl_INC})
|
||||||
|
add_library(mindspore::ssl ALIAS openssl::ssl)
|
||||||
|
add_library(mindspore::crypto ALIAS openssl::crypto)
|
||||||
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
|
@ -173,7 +173,7 @@ class MS_API Model {
|
||||||
/// \return Status of operation
|
/// \return Status of operation
|
||||||
Status UpdateFeatureMaps(const std::vector<MSTensor> &new_weights);
|
Status UpdateFeatureMaps(const std::vector<MSTensor> &new_weights);
|
||||||
|
|
||||||
/// \brief Obtains optimizer params tensors of the model.
|
/// \brief Obtains optimizer params tensors of the model.
|
||||||
///
|
///
|
||||||
/// \return The vector that includes all params tensors.
|
/// \return The vector that includes all params tensors.
|
||||||
std::vector<MSTensor> GetOptimizerParams() const;
|
std::vector<MSTensor> GetOptimizerParams() const;
|
||||||
|
@ -256,17 +256,14 @@ class MS_API Model {
|
||||||
/// \brief Build a model from model buffer so that it can run on a device. Only valid for Lite.
|
/// \brief Build a model from model buffer so that it can run on a device. Only valid for Lite.
|
||||||
///
|
///
|
||||||
/// \param[in] model_data Define the buffer read from a model file.
|
/// \param[in] model_data Define the buffer read from a model file.
|
||||||
/// \param[in] size Define bytes number of model buffer.
|
/// \param[in] data_size Define bytes number of model buffer.
|
||||||
/// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only
|
/// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only
|
||||||
/// ModelType::kMindIR is valid for Lite.
|
/// ModelType::kMindIR is valid for Lite.
|
||||||
/// \param[in] model_context Define the context used to store options during execution.
|
/// \param[in] model_context Define the context used to store options during execution.
|
||||||
/// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16, 24, or 32.
|
|
||||||
/// \param[in] dec_mode Define the decryption mode. Options: AES-GCM, AES-CBC.
|
|
||||||
///
|
///
|
||||||
/// \return Status.
|
/// \return Status.
|
||||||
inline Status Build(const void *model_data, size_t data_size, ModelType model_type,
|
Status Build(const void *model_data, size_t data_size, ModelType model_type,
|
||||||
const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {},
|
const std::shared_ptr<Context> &model_context = nullptr);
|
||||||
const std::string &dec_mode = kDecModeAesGcm);
|
|
||||||
|
|
||||||
/// \brief Load and build a model from model buffer so that it can run on a device. Only valid for Lite.
|
/// \brief Load and build a model from model buffer so that it can run on a device. Only valid for Lite.
|
||||||
///
|
///
|
||||||
|
@ -274,13 +271,40 @@ class MS_API Model {
|
||||||
/// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only
|
/// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only
|
||||||
/// ModelType::kMindIR is valid for Lite.
|
/// ModelType::kMindIR is valid for Lite.
|
||||||
/// \param[in] model_context Define the context used to store options during execution.
|
/// \param[in] model_context Define the context used to store options during execution.
|
||||||
/// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16, 24, or 32.
|
|
||||||
/// \param[in] dec_mode Define the decryption mode. Options: AES-GCM, AES-CBC.
|
|
||||||
///
|
///
|
||||||
/// \return Status.
|
/// \return Status.
|
||||||
inline Status Build(const std::string &model_path, ModelType model_type,
|
Status Build(const std::string &model_path, ModelType model_type,
|
||||||
const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {},
|
const std::shared_ptr<Context> &model_context = nullptr);
|
||||||
const std::string &dec_mode = kDecModeAesGcm);
|
|
||||||
|
/// \brief Build a model from model buffer so that it can run on a device. Only valid for Lite.
|
||||||
|
///
|
||||||
|
/// \param[in] model_data Define the buffer read from a model file.
|
||||||
|
/// \param[in] data_size Define bytes number of model buffer.
|
||||||
|
/// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only
|
||||||
|
/// ModelType::kMindIR is valid for Lite.
|
||||||
|
/// \param[in] model_context Define the context used to store options during execution.
|
||||||
|
/// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16.
|
||||||
|
/// \param[in] dec_mode Define the decryption mode. Options: AES-GCM.
|
||||||
|
/// \param[in] cropto_lib_path Define the openssl library path.
|
||||||
|
///
|
||||||
|
/// \return Status.
|
||||||
|
Status Build(const void *model_data, size_t data_size, ModelType model_type,
|
||||||
|
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode,
|
||||||
|
const std::string &cropto_lib_path);
|
||||||
|
|
||||||
|
/// \brief Load and build a model from model buffer so that it can run on a device. Only valid for Lite.
|
||||||
|
///
|
||||||
|
/// \param[in] model_path Define the model path.
|
||||||
|
/// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only
|
||||||
|
/// ModelType::kMindIR is valid for Lite.
|
||||||
|
/// \param[in] model_context Define the context used to store options during execution.
|
||||||
|
/// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16.
|
||||||
|
/// \param[in] dec_mode Define the decryption mode. Options: AES-GCM.
|
||||||
|
/// \param[in] cropto_lib_path Define the openssl library path.
|
||||||
|
///
|
||||||
|
/// \return Status.
|
||||||
|
Status Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
|
||||||
|
const Key &dec_key, const std::string &dec_mode, const std::string &cropto_lib_path);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend class Serialization;
|
friend class Serialization;
|
||||||
|
@ -291,11 +315,10 @@ class MS_API Model {
|
||||||
std::vector<MSTensor> GetOutputsByNodeName(const std::vector<char> &node_name);
|
std::vector<MSTensor> GetOutputsByNodeName(const std::vector<char> &node_name);
|
||||||
Status LoadConfig(const std::vector<char> &config_path);
|
Status LoadConfig(const std::vector<char> &config_path);
|
||||||
Status UpdateConfig(const std::vector<char> §ion, const std::pair<std::vector<char>, std::vector<char>> &config);
|
Status UpdateConfig(const std::vector<char> §ion, const std::pair<std::vector<char>, std::vector<char>> &config);
|
||||||
Status Build(const void *model_data, size_t data_size, ModelType model_type,
|
Status Build(const std::vector<char> &model_path, ModelType model_type,
|
||||||
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::vector<char> &dec_mode);
|
const std::shared_ptr<Context> &model_context);
|
||||||
Status Build(const std::vector<char> &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
|
Status Build(const std::vector<char> &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
|
||||||
const Key &dec_key, const std::vector<char> &dec_mode);
|
const Key &dec_key, const std::string &dec_mode, const std::vector<char> &cropto_lib_path);
|
||||||
|
|
||||||
std::shared_ptr<ModelImpl> impl_;
|
std::shared_ptr<ModelImpl> impl_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -321,14 +344,15 @@ Status Model::UpdateConfig(const std::string §ion, const std::pair<std::stri
|
||||||
return UpdateConfig(StringToChar(section), config_pair);
|
return UpdateConfig(StringToChar(section), config_pair);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
|
inline Status Model::Build(const std::string &model_path, ModelType model_type,
|
||||||
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode) {
|
const std::shared_ptr<Context> &model_context, const Key &dec_key,
|
||||||
return Build(model_data, data_size, model_type, model_context, dec_key, StringToChar(dec_mode));
|
const std::string &dec_mode, const std::string &cropto_lib_path) {
|
||||||
|
return Build(StringToChar(model_path), model_type, model_context, dec_key, dec_mode, StringToChar(cropto_lib_path));
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Model::Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
|
inline Status Model::Build(const std::string &model_path, ModelType model_type,
|
||||||
const Key &dec_key, const std::string &dec_mode) {
|
const std::shared_ptr<Context> &model_context) {
|
||||||
return Build(StringToChar(model_path), model_type, model_context, dec_key, StringToChar(dec_mode));
|
return Build(StringToChar(model_path), model_type, model_context);
|
||||||
}
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_INCLUDE_API_MODEL_H
|
#endif // MINDSPORE_INCLUDE_API_MODEL_H
|
||||||
|
|
|
@ -52,14 +52,13 @@ Status Model::Build(GraphCell graph_cell, const std::shared_ptr<Context> &model_
|
||||||
return impl_->Build();
|
return impl_->Build();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Model::Build(const void *, size_t, ModelType, const std::shared_ptr<Context> &, const Key &,
|
Status Model::Build(const std::vector<char> &, ModelType, const std::shared_ptr<Context> &, const Key &,
|
||||||
const std::vector<char> &) {
|
const std::string &, const std::vector<char> &) {
|
||||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||||
return kMCFailed;
|
return kMCFailed;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Model::Build(const std::vector<char> &, ModelType, const std::shared_ptr<Context> &, const Key &,
|
Status Model::Build(const std::vector<char> &, ModelType, const std::shared_ptr<Context> &) {
|
||||||
const std::vector<char> &) {
|
|
||||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||||
return kMCFailed;
|
return kMCFailed;
|
||||||
}
|
}
|
||||||
|
|
|
@ -120,7 +120,7 @@ bool ParseMode(const std::string &mode, std::string *alg_mode, std::string *work
|
||||||
}
|
}
|
||||||
|
|
||||||
EVP_CIPHER_CTX *GetEvpCipherCtx(const std::string &work_mode, const Byte *key, int32_t key_len, const Byte *iv,
|
EVP_CIPHER_CTX *GetEvpCipherCtx(const std::string &work_mode, const Byte *key, int32_t key_len, const Byte *iv,
|
||||||
bool is_encrypt) {
|
int iv_len, bool is_encrypt) {
|
||||||
constexpr int32_t key_length_16 = 16;
|
constexpr int32_t key_length_16 = 16;
|
||||||
constexpr int32_t key_length_24 = 24;
|
constexpr int32_t key_length_24 = 24;
|
||||||
constexpr int32_t key_length_32 = 32;
|
constexpr int32_t key_length_32 = 32;
|
||||||
|
@ -163,8 +163,35 @@ EVP_CIPHER_CTX *GetEvpCipherCtx(const std::string &work_mode, const Byte *key, i
|
||||||
int32_t ret = 0;
|
int32_t ret = 0;
|
||||||
auto ctx = EVP_CIPHER_CTX_new();
|
auto ctx = EVP_CIPHER_CTX_new();
|
||||||
if (is_encrypt) {
|
if (is_encrypt) {
|
||||||
|
ret = EVP_EncryptInit_ex(ctx, funcPtr(), NULL, NULL, NULL);
|
||||||
|
if (ret != 1) {
|
||||||
|
MS_LOG(ERROR) << "EVP_EncryptInit_ex failed";
|
||||||
|
EVP_CIPHER_CTX_free(ctx);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, iv_len, NULL) != 1) {
|
||||||
|
MS_LOG(ERROR) << "EVP_EncryptInit_ex failed";
|
||||||
|
EVP_CIPHER_CTX_free(ctx);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
ret = EVP_EncryptInit_ex(ctx, funcPtr(), NULL, key, iv);
|
ret = EVP_EncryptInit_ex(ctx, funcPtr(), NULL, key, iv);
|
||||||
|
if (ret != 1) {
|
||||||
|
MS_LOG(ERROR) << "EVP_EncryptInit_ex failed";
|
||||||
|
EVP_CIPHER_CTX_free(ctx);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
|
ret = EVP_DecryptInit_ex(ctx, funcPtr(), NULL, NULL, NULL);
|
||||||
|
if (ret != 1) {
|
||||||
|
MS_LOG(ERROR) << "EVP_DecryptInit_ex failed";
|
||||||
|
EVP_CIPHER_CTX_free(ctx);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, iv_len, NULL) != 1) {
|
||||||
|
MS_LOG(ERROR) << "EVP_DecryptInit_ex failed";
|
||||||
|
EVP_CIPHER_CTX_free(ctx);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
ret = EVP_DecryptInit_ex(ctx, funcPtr(), NULL, key, iv);
|
ret = EVP_DecryptInit_ex(ctx, funcPtr(), NULL, key, iv);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -183,7 +210,7 @@ EVP_CIPHER_CTX *GetEvpCipherCtx(const std::string &work_mode, const Byte *key, i
|
||||||
}
|
}
|
||||||
|
|
||||||
bool BlockEncrypt(Byte *encrypt_data, size_t *encrypt_data_len, const std::vector<Byte> &plain_data, const Byte *key,
|
bool BlockEncrypt(Byte *encrypt_data, size_t *encrypt_data_len, const std::vector<Byte> &plain_data, const Byte *key,
|
||||||
int32_t key_len, const std::string &enc_mode) {
|
int32_t key_len, const std::string &enc_mode, unsigned char *tag) {
|
||||||
size_t encrypt_data_buf_len = *encrypt_data_len;
|
size_t encrypt_data_buf_len = *encrypt_data_len;
|
||||||
int32_t cipher_len = 0;
|
int32_t cipher_len = 0;
|
||||||
int32_t iv_len = AES_BLOCK_SIZE;
|
int32_t iv_len = AES_BLOCK_SIZE;
|
||||||
|
@ -201,7 +228,7 @@ bool BlockEncrypt(Byte *encrypt_data, size_t *encrypt_data_len, const std::vecto
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto ctx = GetEvpCipherCtx(work_mode, key, key_len, iv.data(), true);
|
auto ctx = GetEvpCipherCtx(work_mode, key, key_len, iv.data(), static_cast<int32_t>(iv.size()), true);
|
||||||
if (ctx == nullptr) {
|
if (ctx == nullptr) {
|
||||||
MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX.";
|
MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX.";
|
||||||
return false;
|
return false;
|
||||||
|
@ -214,15 +241,19 @@ bool BlockEncrypt(Byte *encrypt_data, size_t *encrypt_data_len, const std::vecto
|
||||||
MS_LOG(ERROR) << "EVP_EncryptUpdate failed";
|
MS_LOG(ERROR) << "EVP_EncryptUpdate failed";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (work_mode == "CBC") {
|
int32_t flen = 0;
|
||||||
int32_t flen = 0;
|
ret_evp = EVP_EncryptFinal_ex(ctx, cipher_data_buf.data() + cipher_len, &flen);
|
||||||
ret_evp = EVP_EncryptFinal_ex(ctx, cipher_data_buf.data() + cipher_len, &flen);
|
if (ret_evp != 1) {
|
||||||
if (ret_evp != 1) {
|
MS_LOG(ERROR) << "EVP_EncryptFinal_ex failed";
|
||||||
MS_LOG(ERROR) << "EVP_EncryptFinal_ex failed";
|
return false;
|
||||||
return false;
|
|
||||||
}
|
|
||||||
cipher_len += flen;
|
|
||||||
}
|
}
|
||||||
|
cipher_len += flen;
|
||||||
|
|
||||||
|
if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_GET_TAG, Byte16, tag) != 1) {
|
||||||
|
MS_LOG(ERROR) << "EVP_CIPHER_CTX_ctrl failed";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
EVP_CIPHER_CTX_free(ctx);
|
EVP_CIPHER_CTX_free(ctx);
|
||||||
|
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
|
@ -266,7 +297,7 @@ bool BlockEncrypt(Byte *encrypt_data, size_t *encrypt_data_len, const std::vecto
|
||||||
}
|
}
|
||||||
|
|
||||||
bool BlockDecrypt(Byte *plain_data, int32_t *plain_len, const Byte *encrypt_data, size_t encrypt_len, const Byte *key,
|
bool BlockDecrypt(Byte *plain_data, int32_t *plain_len, const Byte *encrypt_data, size_t encrypt_len, const Byte *key,
|
||||||
int32_t key_len, const std::string &dec_mode) {
|
int32_t key_len, const std::string &dec_mode, unsigned char *tag) {
|
||||||
std::string alg_mode;
|
std::string alg_mode;
|
||||||
std::string work_mode;
|
std::string work_mode;
|
||||||
if (!ParseMode(dec_mode, &alg_mode, &work_mode)) {
|
if (!ParseMode(dec_mode, &alg_mode, &work_mode)) {
|
||||||
|
@ -277,7 +308,7 @@ bool BlockDecrypt(Byte *plain_data, int32_t *plain_len, const Byte *encrypt_data
|
||||||
if (!ParseEncryptData(encrypt_data, encrypt_len, &iv, &cipher_data)) {
|
if (!ParseEncryptData(encrypt_data, encrypt_len, &iv, &cipher_data)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
auto ctx = GetEvpCipherCtx(work_mode, key, key_len, iv.data(), false);
|
auto ctx = GetEvpCipherCtx(work_mode, key, key_len, iv.data(), iv.size(), false);
|
||||||
if (ctx == nullptr) {
|
if (ctx == nullptr) {
|
||||||
MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX.";
|
MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX.";
|
||||||
return false;
|
return false;
|
||||||
|
@ -288,15 +319,20 @@ bool BlockDecrypt(Byte *plain_data, int32_t *plain_len, const Byte *encrypt_data
|
||||||
MS_LOG(ERROR) << "EVP_DecryptUpdate failed";
|
MS_LOG(ERROR) << "EVP_DecryptUpdate failed";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (work_mode == "CBC") {
|
|
||||||
int32_t mlen = 0;
|
if (!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_TAG, Byte16, tag)) {
|
||||||
ret = EVP_DecryptFinal_ex(ctx, plain_data + *plain_len, &mlen);
|
MS_LOG(ERROR) << "EVP_CIPHER_CTX_ctrl failed";
|
||||||
if (ret != 1) {
|
return false;
|
||||||
MS_LOG(ERROR) << "EVP_DecryptFinal_ex failed";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
*plain_len += mlen;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int32_t mlen = 0;
|
||||||
|
ret = EVP_DecryptFinal_ex(ctx, plain_data + *plain_len, &mlen);
|
||||||
|
if (ret != 1) {
|
||||||
|
MS_LOG(ERROR) << "EVP_DecryptFinal_ex failed";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
*plain_len += mlen;
|
||||||
|
|
||||||
EVP_CIPHER_CTX_free(ctx);
|
EVP_CIPHER_CTX_free(ctx);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -319,7 +355,9 @@ std::unique_ptr<Byte[]> Encrypt(size_t *encrypt_len, const Byte *plain_data, siz
|
||||||
size_t block_enc_len = block_enc_buf.size();
|
size_t block_enc_len = block_enc_buf.size();
|
||||||
size_t cur_block_size = std::min(MAX_BLOCK_SIZE, plain_len - offset);
|
size_t cur_block_size = std::min(MAX_BLOCK_SIZE, plain_len - offset);
|
||||||
block_buf.assign(plain_data + offset, plain_data + offset + cur_block_size);
|
block_buf.assign(plain_data + offset, plain_data + offset + cur_block_size);
|
||||||
if (!BlockEncrypt(block_enc_buf.data(), &block_enc_len, block_buf, key, static_cast<int32_t>(key_len), enc_mode)) {
|
unsigned char tag[Byte16];
|
||||||
|
if (!BlockEncrypt(block_enc_buf.data(), &block_enc_len, block_buf, key, static_cast<int32_t>(key_len), enc_mode,
|
||||||
|
tag)) {
|
||||||
MS_LOG(ERROR) << "Failed to encrypt data, please check if enc_key or enc_mode is valid.";
|
MS_LOG(ERROR) << "Failed to encrypt data, please check if enc_key or enc_mode is valid.";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -332,6 +370,13 @@ std::unique_ptr<Byte[]> Encrypt(size_t *encrypt_len, const Byte *plain_data, siz
|
||||||
}
|
}
|
||||||
*encrypt_len += sizeof(int32_t);
|
*encrypt_len += sizeof(int32_t);
|
||||||
|
|
||||||
|
capacity = std::min(encrypt_buf_len - *encrypt_len, SECUREC_MEM_MAX_LEN); // avoid dest size over 2gb
|
||||||
|
ret = memcpy_s(encrypt_data.get() + *encrypt_len, capacity, tag, Byte16);
|
||||||
|
if (ret != 0) {
|
||||||
|
MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret;
|
||||||
|
}
|
||||||
|
*encrypt_len += Byte16;
|
||||||
|
|
||||||
capacity = std::min(encrypt_buf_len - *encrypt_len, SECUREC_MEM_MAX_LEN);
|
capacity = std::min(encrypt_buf_len - *encrypt_len, SECUREC_MEM_MAX_LEN);
|
||||||
ret = memcpy_s(encrypt_data.get() + *encrypt_len, capacity, block_enc_buf.data(), block_enc_len);
|
ret = memcpy_s(encrypt_data.get() + *encrypt_len, capacity, block_enc_buf.data(), block_enc_len);
|
||||||
if (ret != 0) {
|
if (ret != 0) {
|
||||||
|
@ -371,6 +416,10 @@ std::unique_ptr<Byte[]> Decrypt(size_t *decrypt_len, const std::string &encrypt_
|
||||||
MS_LOG(ERROR) << "File \"" << encrypt_data_path << "\" is not an encrypted file and cannot be decrypted";
|
MS_LOG(ERROR) << "File \"" << encrypt_data_path << "\" is not an encrypted file and cannot be decrypted";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
unsigned char tag[Byte16];
|
||||||
|
fid.read(reinterpret_cast<char *>(tag), Byte16);
|
||||||
|
|
||||||
fid.read(int_buf.data(), static_cast<int64_t>(sizeof(int32_t)));
|
fid.read(int_buf.data(), static_cast<int64_t>(sizeof(int32_t)));
|
||||||
auto block_size = ByteToInt(reinterpret_cast<Byte *>(int_buf.data()), int_buf.size());
|
auto block_size = ByteToInt(reinterpret_cast<Byte *>(int_buf.data()), int_buf.size());
|
||||||
if (block_size < 0) {
|
if (block_size < 0) {
|
||||||
|
@ -379,7 +428,7 @@ std::unique_ptr<Byte[]> Decrypt(size_t *decrypt_len, const std::string &encrypt_
|
||||||
}
|
}
|
||||||
fid.read(block_buf.data(), static_cast<int64_t>(block_size));
|
fid.read(block_buf.data(), static_cast<int64_t>(block_size));
|
||||||
if (!(BlockDecrypt(decrypt_block_buf.data(), &decrypt_block_len, reinterpret_cast<Byte *>(block_buf.data()),
|
if (!(BlockDecrypt(decrypt_block_buf.data(), &decrypt_block_len, reinterpret_cast<Byte *>(block_buf.data()),
|
||||||
static_cast<size_t>(block_size), key, static_cast<int32_t>(key_len), dec_mode))) {
|
static_cast<size_t>(block_size), key, static_cast<int32_t>(key_len), dec_mode, tag))) {
|
||||||
MS_LOG(ERROR) << "Failed to decrypt data, please check if dec_key or dec_mode is valid";
|
MS_LOG(ERROR) << "Failed to decrypt data, please check if dec_key or dec_mode is valid";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -409,6 +458,10 @@ std::unique_ptr<Byte[]> Decrypt(size_t *decrypt_len, const Byte *model_data, siz
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
*decrypt_len = 0;
|
*decrypt_len = 0;
|
||||||
while (offset < data_size) {
|
while (offset < data_size) {
|
||||||
|
if (offset + sizeof(int32_t) > data_size) {
|
||||||
|
MS_LOG(ERROR) << "assign len is invalid.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
int_buf.assign(model_data + offset, model_data + offset + sizeof(int32_t));
|
int_buf.assign(model_data + offset, model_data + offset + sizeof(int32_t));
|
||||||
offset += int_buf.size();
|
offset += int_buf.size();
|
||||||
auto cipher_flag = ByteToInt(reinterpret_cast<Byte *>(int_buf.data()), int_buf.size());
|
auto cipher_flag = ByteToInt(reinterpret_cast<Byte *>(int_buf.data()), int_buf.size());
|
||||||
|
@ -416,27 +469,44 @@ std::unique_ptr<Byte[]> Decrypt(size_t *decrypt_len, const Byte *model_data, siz
|
||||||
MS_LOG(ERROR) << "model_data is not encrypted and therefore cannot be decrypted.";
|
MS_LOG(ERROR) << "model_data is not encrypted and therefore cannot be decrypted.";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
unsigned char tag[Byte16];
|
||||||
|
if (offset + Byte16 > data_size) {
|
||||||
|
MS_LOG(ERROR) << "buffer is invalid.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto ret = memcpy_s(tag, Byte16, model_data + offset, Byte16);
|
||||||
|
if (ret != EOK) {
|
||||||
|
MS_LOG(EXCEPTION) << "memcpy_s failed " << ret;
|
||||||
|
}
|
||||||
|
offset += Byte16;
|
||||||
|
if (offset + sizeof(int32_t) > data_size) {
|
||||||
|
MS_LOG(ERROR) << "assign len is invalid.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
int_buf.assign(model_data + offset, model_data + offset + sizeof(int32_t));
|
int_buf.assign(model_data + offset, model_data + offset + sizeof(int32_t));
|
||||||
offset += int_buf.size();
|
offset += int_buf.size();
|
||||||
auto block_size = ByteToInt(reinterpret_cast<Byte *>(int_buf.data()), int_buf.size());
|
auto block_size = ByteToInt(reinterpret_cast<Byte *>(int_buf.data()), int_buf.size());
|
||||||
if (block_size < 0) {
|
if (block_size <= 0) {
|
||||||
MS_LOG(ERROR) << "The block_size read from the cipher data must be not negative, but got " << block_size;
|
MS_LOG(ERROR) << "The block_size read from the cipher data must be not negative, but got " << block_size;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
if (offset + block_size > data_size) {
|
||||||
|
MS_LOG(ERROR) << "assign len is invalid.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
block_buf.assign(model_data + offset, model_data + offset + block_size);
|
block_buf.assign(model_data + offset, model_data + offset + block_size);
|
||||||
offset += block_buf.size();
|
offset += block_buf.size();
|
||||||
if (!(BlockDecrypt(decrypt_block_buf.data(), &decrypt_block_len, reinterpret_cast<Byte *>(block_buf.data()),
|
if (!(BlockDecrypt(decrypt_block_buf.data(), &decrypt_block_len, reinterpret_cast<Byte *>(block_buf.data()),
|
||||||
block_buf.size(), key, static_cast<int32_t>(key_len), dec_mode))) {
|
block_buf.size(), key, static_cast<int32_t>(key_len), dec_mode, tag))) {
|
||||||
MS_LOG(ERROR) << "Failed to decrypt data, please check if dec_key or dec_mode is valid";
|
MS_LOG(ERROR) << "Failed to decrypt data, please check if dec_key or dec_mode is valid";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
size_t capacity = std::min(data_size - *decrypt_len, SECUREC_MEM_MAX_LEN);
|
ret = memcpy_s(decrypt_data.get() + *decrypt_len, data_size, decrypt_block_buf.data(),
|
||||||
auto ret = memcpy_s(decrypt_data.get() + *decrypt_len, capacity, decrypt_block_buf.data(),
|
static_cast<size_t>(decrypt_block_len));
|
||||||
static_cast<size_t>(decrypt_block_len));
|
if (ret != EOK) {
|
||||||
if (ret != 0) {
|
MS_LOG(EXCEPTION) << "memcpy_s failed " << ret;
|
||||||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
*decrypt_len += static_cast<size_t>(decrypt_block_len);
|
*decrypt_len += static_cast<size_t>(decrypt_block_len);
|
||||||
}
|
}
|
||||||
return decrypt_data;
|
return decrypt_data;
|
||||||
|
|
|
@ -26,6 +26,7 @@ namespace mindspore {
|
||||||
constexpr size_t MAX_BLOCK_SIZE = 512 * 1024 * 1024; // Maximum ciphertext segment, units is Byte
|
constexpr size_t MAX_BLOCK_SIZE = 512 * 1024 * 1024; // Maximum ciphertext segment, units is Byte
|
||||||
constexpr size_t RESERVED_BYTE_PER_BLOCK = 50; // Reserved byte per block to save addition info
|
constexpr size_t RESERVED_BYTE_PER_BLOCK = 50; // Reserved byte per block to save addition info
|
||||||
constexpr unsigned int MAGIC_NUM = 0x7F3A5ED8; // Magic number
|
constexpr unsigned int MAGIC_NUM = 0x7F3A5ED8; // Magic number
|
||||||
|
constexpr size_t Byte16 = 16;
|
||||||
|
|
||||||
MS_CORE_API std::unique_ptr<Byte[]> Encrypt(size_t *encrypt_len, const Byte *plain_data, size_t plain_len,
|
MS_CORE_API std::unique_ptr<Byte[]> Encrypt(size_t *encrypt_len, const Byte *plain_data, size_t plain_len,
|
||||||
const Byte *key, size_t key_len, const std::string &enc_mode);
|
const Byte *key, size_t key_len, const std::string &enc_mode);
|
||||||
|
|
|
@ -34,7 +34,7 @@ option(MSLITE_ENABLE_V0 "support v0 schema" on)
|
||||||
option(MSLITE_ENABLE_FP16 "Whether to compile Fp16 operator" off)
|
option(MSLITE_ENABLE_FP16 "Whether to compile Fp16 operator" off)
|
||||||
option(MSLITE_ENABLE_INT8 "Whether to compile Int8 operator" on)
|
option(MSLITE_ENABLE_INT8 "Whether to compile Int8 operator" on)
|
||||||
option(MSLITE_ENABLE_ACL "enable ACL" off)
|
option(MSLITE_ENABLE_ACL "enable ACL" off)
|
||||||
option(MSLITE_ENABLE_MODEL_ENCRYPTION "enable model encryption, only converter support" on)
|
option(MSLITE_ENABLE_MODEL_ENCRYPTION "enable model encryption, only converter support" off)
|
||||||
option(MSLITE_ENABLE_SPARSE_COMPUTE "enable sparse kernel" off)
|
option(MSLITE_ENABLE_SPARSE_COMPUTE "enable sparse kernel" off)
|
||||||
option(MSLITE_ENABLE_RUNTIME_CONVERT "enable runtime convert" off)
|
option(MSLITE_ENABLE_RUNTIME_CONVERT "enable runtime convert" off)
|
||||||
option(MSLITE_ENABLE_RUNTIME_GLOG "enable runtime glog" off)
|
option(MSLITE_ENABLE_RUNTIME_GLOG "enable runtime glog" off)
|
||||||
|
@ -127,7 +127,11 @@ if(DEFINED ENV{MSLITE_MINDDATA_IMPLEMENT})
|
||||||
set(MSLITE_MINDDATA_IMPLEMENT $ENV{MSLITE_MINDDATA_IMPLEMENT})
|
set(MSLITE_MINDDATA_IMPLEMENT $ENV{MSLITE_MINDDATA_IMPLEMENT})
|
||||||
endif()
|
endif()
|
||||||
if(DEFINED ENV{MSLITE_ENABLE_MODEL_ENCRYPTION})
|
if(DEFINED ENV{MSLITE_ENABLE_MODEL_ENCRYPTION})
|
||||||
set(MSLITE_ENABLE_MODEL_ENCRYPTION $ENV{MSLITE_ENABLE_MODEL_ENCRYPTION})
|
if((${CMAKE_SYSTEM_NAME} MATCHES "Linux" AND PLATFORM_X86_64) OR (PLATFORM_ARM AND ANDROID_NDK_TOOLCHAIN_INCLUDED))
|
||||||
|
set(MSLITE_ENABLE_MODEL_ENCRYPTION $ENV{MSLITE_ENABLE_MODEL_ENCRYPTION})
|
||||||
|
else()
|
||||||
|
set(MSLITE_ENABLE_MODEL_ENCRYPTION OFF)
|
||||||
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(DEFINED ENV{MSLITE_ENABLE_RUNTIME_CONVERT})
|
if(DEFINED ENV{MSLITE_ENABLE_RUNTIME_CONVERT})
|
||||||
|
@ -227,7 +231,7 @@ if(PLATFORM_ARM64 OR PLATFORM_ARM32)
|
||||||
endif()
|
endif()
|
||||||
set(MSLITE_ENABLE_RUNTIME_GLOG off)
|
set(MSLITE_ENABLE_RUNTIME_GLOG off)
|
||||||
set(MSLITE_ENABLE_RUNTIME_CONVERT off)
|
set(MSLITE_ENABLE_RUNTIME_CONVERT off)
|
||||||
#set for cross - compiling toolchain
|
#set for cross - compiling toolchain
|
||||||
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH)
|
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH)
|
||||||
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE BOTH)
|
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE BOTH)
|
||||||
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE BOTH)
|
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE BOTH)
|
||||||
|
@ -540,13 +544,15 @@ if(MSLITE_ENABLE_CONVERTER)
|
||||||
include_directories(${PYTHON_INCLUDE_DIRS})
|
include_directories(${PYTHON_INCLUDE_DIRS})
|
||||||
include(${TOP_DIR}/cmake/external_libs/eigen.cmake)
|
include(${TOP_DIR}/cmake/external_libs/eigen.cmake)
|
||||||
include(${TOP_DIR}/cmake/external_libs/protobuf.cmake)
|
include(${TOP_DIR}/cmake/external_libs/protobuf.cmake)
|
||||||
if(MSLITE_ENABLE_MODEL_ENCRYPTION)
|
|
||||||
find_package(Patch)
|
|
||||||
include(${TOP_DIR}/cmake/external_libs/openssl.cmake)
|
|
||||||
endif()
|
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if(MSLITE_ENABLE_MODEL_ENCRYPTION)
|
||||||
|
find_package(Patch)
|
||||||
|
include(${TOP_DIR}/cmake/external_libs/openssl.cmake)
|
||||||
|
add_compile_definitions(ENABLE_OPENSSL)
|
||||||
|
endif()
|
||||||
|
|
||||||
if(MSLITE_ENABLE_MINDRT)
|
if(MSLITE_ENABLE_MINDRT)
|
||||||
add_compile_definitions(ENABLE_MINDRT)
|
add_compile_definitions(ENABLE_MINDRT)
|
||||||
endif()
|
endif()
|
||||||
|
@ -590,7 +596,7 @@ if(NOT PLATFORM_ARM)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "lite" OR MSLITE_MINDDATA_IMPLEMENT STREQUAL "full"
|
if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "lite" OR MSLITE_MINDDATA_IMPLEMENT STREQUAL "full"
|
||||||
OR MSLITE_MINDDATA_IMPLEMENT STREQUAL "wrapper")
|
OR MSLITE_MINDDATA_IMPLEMENT STREQUAL "wrapper")
|
||||||
add_compile_definitions(ENABLE_ANDROID)
|
add_compile_definitions(ENABLE_ANDROID)
|
||||||
if(NOT PLATFORM_ARM32 AND NOT PLATFORM_ARM64)
|
if(NOT PLATFORM_ARM32 AND NOT PLATFORM_ARM64)
|
||||||
add_compile_definitions(ENABLE_MD_LITE_X86_64)
|
add_compile_definitions(ENABLE_MD_LITE_X86_64)
|
||||||
|
@ -605,7 +611,7 @@ endif()
|
||||||
|
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/src/ops)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/src/ops)
|
||||||
if(ANDROID_NDK_TOOLCHAIN_INCLUDED)
|
if(ANDROID_NDK_TOOLCHAIN_INCLUDED)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter/micro/coder)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter/micro/coder)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/src)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/src)
|
||||||
|
|
|
@ -206,6 +206,7 @@ build_lite() {
|
||||||
LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_ENABLE_FP16=off -DMSLITE_ENABLE_TRAIN=off -DMSLITE_GPU_BACKEND=off"
|
LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_ENABLE_FP16=off -DMSLITE_ENABLE_TRAIN=off -DMSLITE_GPU_BACKEND=off"
|
||||||
else
|
else
|
||||||
checkndk
|
checkndk
|
||||||
|
export PATH=${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/bin:${ANDROID_NDK}/toolchains/arm-linux-androideabi-4.9/prebuilt/linux-x86_64/bin:${PATH}
|
||||||
CMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake
|
CMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake
|
||||||
LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_MINDDATA_IMPLEMENT=lite_cv"
|
LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_MINDDATA_IMPLEMENT=lite_cv"
|
||||||
LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_ENABLE_FP16=on"
|
LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_ENABLE_FP16=on"
|
||||||
|
@ -237,6 +238,7 @@ build_lite() {
|
||||||
ARM64_COMPILE_CONVERTER=ON
|
ARM64_COMPILE_CONVERTER=ON
|
||||||
else
|
else
|
||||||
checkndk
|
checkndk
|
||||||
|
export PATH=${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/bin:${ANDROID_NDK}/toolchains/arm-linux-androideabi-4.9/prebuilt/linux-x86_64/bin:${PATH}
|
||||||
CMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake
|
CMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake
|
||||||
LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DANDROID_NATIVE_API_LEVEL=19 -DANDROID_NDK=${ANDROID_NDK} -DANDROID_ABI=arm64-v8a -DANDROID_TOOLCHAIN_NAME=aarch64-linux-android-clang -DANDROID_STL=${MSLITE_ANDROID_STL}"
|
LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DANDROID_NATIVE_API_LEVEL=19 -DANDROID_NDK=${ANDROID_NDK} -DANDROID_ABI=arm64-v8a -DANDROID_TOOLCHAIN_NAME=aarch64-linux-android-clang -DANDROID_STL=${MSLITE_ANDROID_STL}"
|
||||||
LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_MINDDATA_IMPLEMENT=lite_cv"
|
LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_MINDDATA_IMPLEMENT=lite_cv"
|
||||||
|
|
|
@ -58,19 +58,19 @@ public class Model {
|
||||||
/**
|
/**
|
||||||
* Build model.
|
* Build model.
|
||||||
*
|
*
|
||||||
* @param buffer model buffer.
|
* @param buffer model buffer.
|
||||||
* @param modelType model type.
|
* @param modelType model type.
|
||||||
* @param context model build context.
|
* @param context model build context.
|
||||||
* @param dec_key define the key used to decrypt the ciphertext model. The key length is 16, 24, or 32.
|
* @param dec_key define the key used to decrypt the ciphertext model. The key length is 16.
|
||||||
* @param dec_mode define the decryption mode. Options: AES-GCM, AES-CBC.
|
* @param dec_mode define the decryption mode. Options: AES-GCM.
|
||||||
|
* @param cropto_lib_path define the openssl library path.
|
||||||
* @return model build status.
|
* @return model build status.
|
||||||
*/
|
*/
|
||||||
public boolean build(final MappedByteBuffer buffer, int modelType, MSContext context, char[] dec_key,
|
public boolean build(final MappedByteBuffer buffer, int modelType, MSContext context, char[] dec_key, String dec_mode, String cropto_lib_path) {
|
||||||
String dec_mode) {
|
|
||||||
if (context == null || buffer == null || dec_key == null || dec_mode == null) {
|
if (context == null || buffer == null || dec_key == null || dec_mode == null) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
modelPtr = this.buildByBuffer(buffer, modelType, context.getMSContextPtr(), dec_key, dec_mode);
|
modelPtr = this.buildByBuffer(buffer, modelType, context.getMSContextPtr(), dec_key, dec_mode, cropto_lib_path);
|
||||||
return modelPtr != 0;
|
return modelPtr != 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -86,7 +86,7 @@ public class Model {
|
||||||
if (context == null || buffer == null) {
|
if (context == null || buffer == null) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
modelPtr = this.buildByBuffer(buffer, modelType, context.getMSContextPtr(), null, "");
|
modelPtr = this.buildByBuffer(buffer, modelType, context.getMSContextPtr(), null, "", "");
|
||||||
return modelPtr != 0;
|
return modelPtr != 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -94,18 +94,19 @@ public class Model {
|
||||||
/**
|
/**
|
||||||
* Build model.
|
* Build model.
|
||||||
*
|
*
|
||||||
* @param modelPath model path.
|
* @param modelPath model path.
|
||||||
* @param modelType model type.
|
* @param modelType model type.
|
||||||
* @param context model build context.
|
* @param context model build context.
|
||||||
* @param dec_key define the key used to decrypt the ciphertext model. The key length is 16, 24, or 32.
|
* @param dec_key define the key used to decrypt the ciphertext model. The key length is 16.
|
||||||
* @param dec_mode define the decryption mode. Options: AES-GCM, AES-CBC.
|
* @param dec_mode define the decryption mode. Options: AES-GCM.
|
||||||
|
* @param cropto_lib_path define the openssl library path.
|
||||||
* @return model build status.
|
* @return model build status.
|
||||||
*/
|
*/
|
||||||
public boolean build(String modelPath, int modelType, MSContext context, char[] dec_key, String dec_mode) {
|
public boolean build(String modelPath, int modelType, MSContext context, char[] dec_key, String dec_mode, String cropto_lib_path) {
|
||||||
if (context == null || modelPath == null || dec_key == null || dec_mode == null) {
|
if (context == null || modelPath == null || dec_key == null || dec_mode == null) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
modelPtr = this.buildByPath(modelPath, modelType, context.getMSContextPtr(), dec_key, dec_mode);
|
modelPtr = this.buildByPath(modelPath, modelType, context.getMSContextPtr(), dec_key, dec_mode, cropto_lib_path);
|
||||||
return modelPtr != 0;
|
return modelPtr != 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -121,7 +122,7 @@ public class Model {
|
||||||
if (context == null || modelPath == null) {
|
if (context == null || modelPath == null) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
modelPtr = this.buildByPath(modelPath, modelType, context.getMSContextPtr(), null, "");
|
modelPtr = this.buildByPath(modelPath, modelType, context.getMSContextPtr(), null, "", "");
|
||||||
return modelPtr != 0;
|
return modelPtr != 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -256,8 +257,7 @@ public class Model {
|
||||||
* @param outputTensorNames tensor name used for export inference graph.
|
* @param outputTensorNames tensor name used for export inference graph.
|
||||||
* @return Whether the export is successful.
|
* @return Whether the export is successful.
|
||||||
*/
|
*/
|
||||||
public boolean export(String fileName, int quantizationType, boolean isOnlyExportInfer,
|
public boolean export(String fileName, int quantizationType, boolean isOnlyExportInfer, List<String> outputTensorNames) {
|
||||||
List<String> outputTensorNames) {
|
|
||||||
if (fileName == null) {
|
if (fileName == null) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -355,10 +355,11 @@ public class Model {
|
||||||
|
|
||||||
private native long buildByGraph(long graphPtr, long contextPtr, long cfgPtr);
|
private native long buildByGraph(long graphPtr, long contextPtr, long cfgPtr);
|
||||||
|
|
||||||
private native long buildByPath(String modelPath, int modelType, long contextPtr, char[] dec_key, String dec_mod);
|
private native long buildByPath(String modelPath, int modelType, long contextPtr,
|
||||||
|
char[] dec_key, String dec_mod, String cropto_lib_path);
|
||||||
|
|
||||||
private native long buildByBuffer(MappedByteBuffer buffer, int modelType, long contextPtr, char[] dec_key,
|
private native long buildByBuffer(MappedByteBuffer buffer, int modelType, long contextPtr,
|
||||||
String dec_mod);
|
char[] dec_key, String dec_mod, String cropto_lib_path);
|
||||||
|
|
||||||
private native List<Long> getInputs(long modelPtr);
|
private native List<Long> getInputs(long modelPtr);
|
||||||
|
|
||||||
|
@ -380,8 +381,7 @@ public class Model {
|
||||||
|
|
||||||
private native boolean resize(long modelPtr, long[] inputs, int[][] dims);
|
private native boolean resize(long modelPtr, long[] inputs, int[][] dims);
|
||||||
|
|
||||||
private native boolean export(long modelPtr, String fileName, int quantizationType, boolean isOnlyExportInfer,
|
private native boolean export(long modelPtr, String fileName, int quantizationType, boolean isOnlyExportInfer, String[] outputTensorNames);
|
||||||
String[] outputTensorNames);
|
|
||||||
|
|
||||||
private native List<Long> getFeatureMaps(long modelPtr);
|
private native List<Long> getFeatureMaps(long modelPtr);
|
||||||
|
|
||||||
|
@ -389,6 +389,5 @@ public class Model {
|
||||||
|
|
||||||
private native boolean setLearningRate(long modelPtr, float learning_rate);
|
private native boolean setLearningRate(long modelPtr, float learning_rate);
|
||||||
|
|
||||||
private native boolean setupVirtualBatch(long modelPtr, int virtualBatchMultiplier, float learningRate,
|
private native boolean setupVirtualBatch(long modelPtr, int virtualBatchMultiplier, float learningRate, float momentum);
|
||||||
float momentum);
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -68,7 +68,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByGraph(JNIEnv
|
||||||
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByBuffer(JNIEnv *env, jobject thiz,
|
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByBuffer(JNIEnv *env, jobject thiz,
|
||||||
jobject model_buffer, jint model_type,
|
jobject model_buffer, jint model_type,
|
||||||
jlong context_ptr, jcharArray key_str,
|
jlong context_ptr, jcharArray key_str,
|
||||||
jstring dec_mod) {
|
jstring dec_mod, jstring cropto_lib_path) {
|
||||||
if (model_buffer == nullptr) {
|
if (model_buffer == nullptr) {
|
||||||
MS_LOGE("Buffer from java is nullptr");
|
MS_LOGE("Buffer from java is nullptr");
|
||||||
return reinterpret_cast<jlong>(nullptr);
|
return reinterpret_cast<jlong>(nullptr);
|
||||||
|
@ -116,7 +116,8 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByBuffer(JNIEnv
|
||||||
}
|
}
|
||||||
env->ReleaseCharArrayElements(key_str, key_array, JNI_ABORT);
|
env->ReleaseCharArrayElements(key_str, key_array, JNI_ABORT);
|
||||||
mindspore::Key dec_key{dec_key_data, key_len};
|
mindspore::Key dec_key{dec_key_data, key_len};
|
||||||
status = model->Build(model_buf, buffer_len, c_model_type, context, dec_key, c_dec_mod);
|
auto c_cropto_lib_path = env->GetStringUTFChars(cropto_lib_path, JNI_FALSE);
|
||||||
|
status = model->Build(model_buf, buffer_len, c_model_type, context, dec_key, c_dec_mod, c_cropto_lib_path);
|
||||||
} else {
|
} else {
|
||||||
status = model->Build(model_buf, buffer_len, c_model_type, context);
|
status = model->Build(model_buf, buffer_len, c_model_type, context);
|
||||||
}
|
}
|
||||||
|
@ -130,7 +131,8 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByBuffer(JNIEnv
|
||||||
|
|
||||||
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByPath(JNIEnv *env, jobject thiz, jstring model_path,
|
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByPath(JNIEnv *env, jobject thiz, jstring model_path,
|
||||||
jint model_type, jlong context_ptr,
|
jint model_type, jlong context_ptr,
|
||||||
jcharArray key_str, jstring dec_mod) {
|
jcharArray key_str, jstring dec_mod,
|
||||||
|
jstring cropto_lib_path) {
|
||||||
auto c_model_path = env->GetStringUTFChars(model_path, JNI_FALSE);
|
auto c_model_path = env->GetStringUTFChars(model_path, JNI_FALSE);
|
||||||
mindspore::ModelType c_model_type;
|
mindspore::ModelType c_model_type;
|
||||||
if (model_type >= static_cast<int>(mindspore::kMindIR) && model_type <= static_cast<int>(mindspore::kMindIR_Lite)) {
|
if (model_type >= static_cast<int>(mindspore::kMindIR) && model_type <= static_cast<int>(mindspore::kMindIR_Lite)) {
|
||||||
|
@ -172,7 +174,8 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByPath(JNIEnv *
|
||||||
}
|
}
|
||||||
env->ReleaseCharArrayElements(key_str, key_array, JNI_ABORT);
|
env->ReleaseCharArrayElements(key_str, key_array, JNI_ABORT);
|
||||||
mindspore::Key dec_key{dec_key_data, key_len};
|
mindspore::Key dec_key{dec_key_data, key_len};
|
||||||
status = model->Build(c_model_path, c_model_type, context, dec_key, c_dec_mod);
|
auto c_cropto_lib_path = env->GetStringUTFChars(cropto_lib_path, JNI_FALSE);
|
||||||
|
status = model->Build(c_model_path, c_model_type, context, dec_key, c_dec_mod, c_cropto_lib_path);
|
||||||
} else {
|
} else {
|
||||||
status = model->Build(c_model_path, c_model_type, context);
|
status = model->Build(c_model_path, c_model_type, context);
|
||||||
}
|
}
|
||||||
|
|
|
@ -131,6 +131,14 @@ set(LITE_SRC
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/cpu_info.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/cpu_info.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if(MSLITE_ENABLE_MODEL_ENCRYPTION)
|
||||||
|
set(LITE_SRC
|
||||||
|
${LITE_SRC}
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/common/decrypt.cc
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/common/dynamic_library_loader.cc
|
||||||
|
)
|
||||||
|
endif()
|
||||||
|
|
||||||
if(MSLITE_ENABLE_SERVER_INFERENCE)
|
if(MSLITE_ENABLE_SERVER_INFERENCE)
|
||||||
set(LITE_SRC
|
set(LITE_SRC
|
||||||
${LITE_SRC}
|
${LITE_SRC}
|
||||||
|
@ -272,8 +280,7 @@ set(TRAIN_SRC
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/train/classification_train_accuracy_monitor.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/train/classification_train_accuracy_monitor.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/train/train_export.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/train/train_export.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/train/opt_allocator.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/train/opt_allocator.cc
|
||||||
${TOOLS_DIR}/common/storage.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/common/storage.cc
|
||||||
${TOOLS_DIR}/common/meta_graph_serializer.cc
|
|
||||||
${TOOLS_DIR}/converter/optimizer.cc
|
${TOOLS_DIR}/converter/optimizer.cc
|
||||||
${TOOLS_DIR}/converter/legacy_optimizer/fusion/fusion_pass.cc
|
${TOOLS_DIR}/converter/legacy_optimizer/fusion/fusion_pass.cc
|
||||||
${TOOLS_DIR}/converter/legacy_optimizer/fusion/fusion_pattern.cc
|
${TOOLS_DIR}/converter/legacy_optimizer/fusion/fusion_pattern.cc
|
||||||
|
|
|
@ -0,0 +1,323 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "src/common/decrypt.h"
|
||||||
|
#ifdef ENABLE_OPENSSL
|
||||||
|
#include <openssl/aes.h>
|
||||||
|
#include <openssl/evp.h>
|
||||||
|
#include <openssl/rand.h>
|
||||||
|
#include <regex>
|
||||||
|
#include <vector>
|
||||||
|
#include <fstream>
|
||||||
|
#include <algorithm>
|
||||||
|
#include "src/common/dynamic_library_loader.h"
|
||||||
|
#include "src/common/file_utils.h"
|
||||||
|
#endif
|
||||||
|
#include "src/common/log_adapter.h"
|
||||||
|
#include "src/common/log_util.h"
|
||||||
|
|
||||||
|
#ifndef SECUREC_MEM_MAX_LEN
|
||||||
|
#define SECUREC_MEM_MAX_LEN 0x7fffffffUL
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace mindspore::lite {
|
||||||
|
#ifndef ENABLE_OPENSSL
|
||||||
|
std::unique_ptr<Byte[]> Decrypt(const std::string &lib_path, size_t *, const Byte *, const size_t, const Byte *,
|
||||||
|
const size_t, const std::string &) {
|
||||||
|
MS_LOG(ERROR) << "The feature is only supported on the Linux platform "
|
||||||
|
"when the OPENSSL compilation option is enabled.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
namespace {
|
||||||
|
constexpr size_t MAX_BLOCK_SIZE = 512 * 1024 * 1024; // Maximum ciphertext segment, units is Byte
|
||||||
|
constexpr size_t Byte16 = 16; // Byte16
|
||||||
|
constexpr unsigned int MAGIC_NUM = 0x7F3A5ED8; // Magic number
|
||||||
|
DynamicLibraryLoader loader;
|
||||||
|
} // namespace
|
||||||
|
int32_t ByteToInt(const Byte *byteArray, size_t length) {
|
||||||
|
if (byteArray == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "There is a null pointer in the input parameter.";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
if (length < sizeof(int32_t)) {
|
||||||
|
MS_LOG(ERROR) << "Length of byteArray is " << length << ", less than sizeof(int32_t): 4.";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
return *(reinterpret_cast<const int32_t *>(byteArray));
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ParseEncryptData(const Byte *encrypt_data, size_t encrypt_len, std::vector<Byte> *iv,
|
||||||
|
std::vector<Byte> *cipher_data) {
|
||||||
|
if (encrypt_data == nullptr || iv == nullptr || cipher_data == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "There is a null pointer in the input parameter.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// encrypt_data is organized in order to iv_len, iv, cipher_len, cipher_data
|
||||||
|
std::vector<Byte> int_buf(sizeof(int32_t));
|
||||||
|
if (sizeof(int32_t) > encrypt_len) {
|
||||||
|
MS_LOG(ERROR) << "assign len is invalid.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
int_buf.assign(encrypt_data, encrypt_data + sizeof(int32_t));
|
||||||
|
auto iv_len = ByteToInt(int_buf.data(), int_buf.size());
|
||||||
|
if (iv_len <= 0 || iv_len + sizeof(int32_t) + sizeof(int32_t) > encrypt_len) {
|
||||||
|
MS_LOG(ERROR) << "assign len is invalid.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
int_buf.assign(encrypt_data + iv_len + sizeof(int32_t), encrypt_data + iv_len + sizeof(int32_t) + sizeof(int32_t));
|
||||||
|
auto cipher_len = ByteToInt(int_buf.data(), int_buf.size());
|
||||||
|
if (((static_cast<size_t>(iv_len) + sizeof(int32_t) + static_cast<size_t>(cipher_len) + sizeof(int32_t)) !=
|
||||||
|
encrypt_len)) {
|
||||||
|
MS_LOG(ERROR) << "Failed to parse encrypt data.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
(*iv).assign(encrypt_data + sizeof(int32_t), encrypt_data + sizeof(int32_t) + iv_len);
|
||||||
|
if (cipher_len <= 0 || sizeof(int32_t) + iv_len + sizeof(int32_t) + cipher_len > encrypt_len) {
|
||||||
|
MS_LOG(ERROR) << "assign len is invalid.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
(*cipher_data)
|
||||||
|
.assign(encrypt_data + sizeof(int32_t) + iv_len + sizeof(int32_t),
|
||||||
|
encrypt_data + sizeof(int32_t) + iv_len + sizeof(int32_t) + cipher_len);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ParseMode(const std::string &mode, std::string *alg_mode, std::string *work_mode) {
|
||||||
|
if (alg_mode == nullptr || work_mode == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "There is a null pointer in the input parameter.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
std::smatch results;
|
||||||
|
std::regex re("([A-Z]{3})-([A-Z]{3})");
|
||||||
|
if (!(std::regex_match(mode.c_str(), re) && std::regex_search(mode, results, re))) {
|
||||||
|
MS_LOG(ERROR) << "Mode " << mode << " is invalid.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const size_t index_1 = 1;
|
||||||
|
const size_t index_2 = 2;
|
||||||
|
*alg_mode = results[index_1];
|
||||||
|
*work_mode = results[index_2];
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
EVP_CIPHER_CTX *GetEvpCipherCtx(const std::string &work_mode, const Byte *key, int32_t key_len, const Byte *iv,
|
||||||
|
int iv_len) {
|
||||||
|
constexpr int32_t key_length_16 = 16;
|
||||||
|
constexpr int32_t key_length_24 = 24;
|
||||||
|
constexpr int32_t key_length_32 = 32;
|
||||||
|
const EVP_CIPHER *(*funcPtr)() = nullptr;
|
||||||
|
if (work_mode == "GCM") {
|
||||||
|
switch (key_len) {
|
||||||
|
case key_length_16:
|
||||||
|
funcPtr = (const EVP_CIPHER *(*)())loader.GetFunc("EVP_aes_128_gcm");
|
||||||
|
break;
|
||||||
|
case key_length_24:
|
||||||
|
funcPtr = (const EVP_CIPHER *(*)())loader.GetFunc("EVP_aes_192_gcm");
|
||||||
|
break;
|
||||||
|
case key_length_32:
|
||||||
|
funcPtr = (const EVP_CIPHER *(*)())loader.GetFunc("EVP_aes_256_gcm");
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
MS_LOG(ERROR) << "The key length must be 16, 24 or 32, but got key length is " << key_len << ".";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << "Work mode " << work_mode << " is invalid.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t ret = 0;
|
||||||
|
EVP_CIPHER_CTX *(*EVP_CIPHER_CTX_new)() = (EVP_CIPHER_CTX * (*)()) loader.GetFunc("EVP_CIPHER_CTX_new");
|
||||||
|
EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
|
||||||
|
if (ctx == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "EVP_CIPHER_CTX_new failed";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
int (*EVP_DecryptInit_ex)(EVP_CIPHER_CTX *, const EVP_CIPHER *, ENGINE *, const unsigned char *,
|
||||||
|
const unsigned char *) =
|
||||||
|
(int (*)(EVP_CIPHER_CTX *, const EVP_CIPHER *, ENGINE *, const unsigned char *,
|
||||||
|
const unsigned char *))loader.GetFunc("EVP_DecryptInit_ex");
|
||||||
|
ret = EVP_DecryptInit_ex(ctx, funcPtr(), NULL, NULL, NULL);
|
||||||
|
int (*EVP_CIPHER_CTX_ctrl)(EVP_CIPHER_CTX *, int, int, void *) =
|
||||||
|
(int (*)(EVP_CIPHER_CTX * ctx, int type, int arg, void *ptr)) loader.GetFunc("EVP_CIPHER_CTX_ctrl");
|
||||||
|
if (ret != 1) {
|
||||||
|
MS_LOG(ERROR) << "EVP_DecryptInit_ex failed";
|
||||||
|
void (*EVP_CIPHER_CTX_free)(EVP_CIPHER_CTX *) = (void (*)(EVP_CIPHER_CTX *))loader.GetFunc("EVP_CIPHER_CTX_free");
|
||||||
|
EVP_CIPHER_CTX_free(ctx);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, iv_len, NULL) != 1) {
|
||||||
|
MS_LOG(ERROR) << "EVP_DecryptInit_ex failed";
|
||||||
|
void (*EVP_CIPHER_CTX_free)(EVP_CIPHER_CTX *) = (void (*)(EVP_CIPHER_CTX *))loader.GetFunc("EVP_CIPHER_CTX_free");
|
||||||
|
EVP_CIPHER_CTX_free(ctx);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
ret = EVP_DecryptInit_ex(ctx, funcPtr(), NULL, key, iv);
|
||||||
|
if (ret != 1) {
|
||||||
|
MS_LOG(ERROR) << "EVP_DecryptInit_ex failed";
|
||||||
|
void (*EVP_CIPHER_CTX_free)(EVP_CIPHER_CTX *) = (void (*)(EVP_CIPHER_CTX *))loader.GetFunc("EVP_CIPHER_CTX_free");
|
||||||
|
EVP_CIPHER_CTX_free(ctx);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return ctx;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool BlockDecrypt(Byte *plain_data, int32_t *plain_len, const Byte *encrypt_data, size_t encrypt_len, const Byte *key,
|
||||||
|
int32_t key_len, const std::string &dec_mode, unsigned char *tag) {
|
||||||
|
if (plain_data == nullptr || plain_len == nullptr || encrypt_data == nullptr || key == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "There is a null pointer in the input parameter.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
std::string alg_mode;
|
||||||
|
std::string work_mode;
|
||||||
|
if (!ParseMode(dec_mode, &alg_mode, &work_mode)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
std::vector<Byte> iv;
|
||||||
|
std::vector<Byte> cipher_data;
|
||||||
|
if (!ParseEncryptData(encrypt_data, encrypt_len, &iv, &cipher_data)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto ctx = GetEvpCipherCtx(work_mode, key, key_len, iv.data(), static_cast<int32_t>(iv.size()));
|
||||||
|
if (ctx == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
int (*EVP_DecryptUpdate)(EVP_CIPHER_CTX *, unsigned char *, int *, const unsigned char *, int) =
|
||||||
|
(int (*)(EVP_CIPHER_CTX *, unsigned char *, int *, const unsigned char *, int))loader.GetFunc("EVP_DecryptUpdate");
|
||||||
|
auto ret =
|
||||||
|
EVP_DecryptUpdate(ctx, plain_data, plain_len, cipher_data.data(), static_cast<int32_t>(cipher_data.size()));
|
||||||
|
if (ret != 1) {
|
||||||
|
MS_LOG(ERROR) << "EVP_DecryptUpdate failed";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
int (*EVP_CIPHER_CTX_ctrl)(EVP_CIPHER_CTX *, int, int, void *) =
|
||||||
|
(int (*)(EVP_CIPHER_CTX *, int, int, void *))loader.GetFunc("EVP_CIPHER_CTX_ctrl");
|
||||||
|
if (!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_TAG, Byte16, tag)) {
|
||||||
|
MS_LOG(ERROR) << "EVP_CIPHER_CTX_ctrl failed";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t mlen = 0;
|
||||||
|
int (*EVP_DecryptFinal_ex)(EVP_CIPHER_CTX *, unsigned char *, int *) =
|
||||||
|
(int (*)(EVP_CIPHER_CTX *, unsigned char *, int *))loader.GetFunc("EVP_DecryptFinal_ex");
|
||||||
|
ret = EVP_DecryptFinal_ex(ctx, plain_data + *plain_len, &mlen);
|
||||||
|
if (ret != 1) {
|
||||||
|
MS_LOG(ERROR) << "EVP_DecryptFinal_ex failed";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
*plain_len += mlen;
|
||||||
|
|
||||||
|
void (*EVP_CIPHER_CTX_free)(EVP_CIPHER_CTX *) = (void (*)(EVP_CIPHER_CTX *))loader.GetFunc("EVP_CIPHER_CTX_free");
|
||||||
|
EVP_CIPHER_CTX_free(ctx);
|
||||||
|
iv.assign(iv.size(), 0);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<Byte[]> Decrypt(const std::string &lib_path, size_t *decrypt_len, const Byte *model_data,
|
||||||
|
const size_t data_size, const Byte *key, const size_t key_len,
|
||||||
|
const std::string &dec_mode) {
|
||||||
|
if (model_data == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "model_data is nullptr.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (key == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "key is nullptr.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (decrypt_len == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "decrypt_len is nullptr.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto ret = loader.Open(lib_path);
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "loader open failed.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
std::vector<char> block_buf;
|
||||||
|
std::vector<char> int_buf(sizeof(int32_t));
|
||||||
|
std::vector<Byte> decrypt_block_buf(MAX_BLOCK_SIZE);
|
||||||
|
auto decrypt_data = std::make_unique<Byte[]>(data_size);
|
||||||
|
int32_t decrypt_block_len;
|
||||||
|
|
||||||
|
size_t offset = 0;
|
||||||
|
*decrypt_len = 0;
|
||||||
|
if (dec_mode != "AES-GCM") {
|
||||||
|
MS_LOG(ERROR) << "dec_mode only support AES-GCM.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (key_len != Byte16) {
|
||||||
|
MS_LOG(ERROR) << "key_len only support 16.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
while (offset < data_size) {
|
||||||
|
if (offset + sizeof(int32_t) > data_size) {
|
||||||
|
MS_LOG(ERROR) << "assign len is invalid.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
int_buf.assign(model_data + offset, model_data + offset + sizeof(int32_t));
|
||||||
|
offset += int_buf.size();
|
||||||
|
auto cipher_flag = ByteToInt(reinterpret_cast<Byte *>(int_buf.data()), int_buf.size());
|
||||||
|
if (static_cast<unsigned int>(cipher_flag) != MAGIC_NUM) {
|
||||||
|
MS_LOG(ERROR) << "model_data is not encrypted and therefore cannot be decrypted.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
unsigned char tag[Byte16];
|
||||||
|
if (offset + Byte16 > data_size) {
|
||||||
|
MS_LOG(ERROR) << "buffer is invalid.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
memcpy(tag, model_data + offset, Byte16);
|
||||||
|
offset += Byte16;
|
||||||
|
if (offset + sizeof(int32_t) > data_size) {
|
||||||
|
MS_LOG(ERROR) << "assign len is invalid.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
int_buf.assign(model_data + offset, model_data + offset + sizeof(int32_t));
|
||||||
|
offset += int_buf.size();
|
||||||
|
auto block_size = ByteToInt(reinterpret_cast<Byte *>(int_buf.data()), int_buf.size());
|
||||||
|
if (block_size <= 0) {
|
||||||
|
MS_LOG(ERROR) << "The block_size read from the cipher data must be not negative, but got " << block_size;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (offset + block_size > data_size) {
|
||||||
|
MS_LOG(ERROR) << "assign len is invalid.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
block_buf.assign(model_data + offset, model_data + offset + block_size);
|
||||||
|
offset += block_buf.size();
|
||||||
|
if (!(BlockDecrypt(decrypt_block_buf.data(), &decrypt_block_len, reinterpret_cast<Byte *>(block_buf.data()),
|
||||||
|
block_buf.size(), key, static_cast<int32_t>(key_len), dec_mode, tag))) {
|
||||||
|
MS_LOG(ERROR) << "Failed to decrypt data, please check if dec_key or dec_mode is valid";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
memcpy(decrypt_data.get() + *decrypt_len, decrypt_block_buf.data(), static_cast<size_t>(decrypt_block_len));
|
||||||
|
|
||||||
|
*decrypt_len += static_cast<size_t>(decrypt_block_len);
|
||||||
|
}
|
||||||
|
ret = loader.Close();
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "loader close failed.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return decrypt_data;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
} // namespace mindspore::lite
|
|
@ -0,0 +1,29 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CORE_UTILS_DECRYPT_H_
|
||||||
|
#define MINDSPORE_CORE_UTILS_DECRYPT_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
typedef unsigned char Byte;
|
||||||
|
namespace mindspore::lite {
|
||||||
|
std::unique_ptr<Byte[]> Decrypt(const std::string &lib_path, size_t *decrypt_len, const Byte *model_data,
|
||||||
|
const size_t data_size, const Byte *key, const size_t key_len,
|
||||||
|
const std::string &dec_mode);
|
||||||
|
} // namespace mindspore::lite
|
||||||
|
#endif
|
|
@ -30,10 +30,13 @@ namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
int DynamicLibraryLoader::Open(const std::string &lib_path) {
|
int DynamicLibraryLoader::Open(const std::string &lib_path) {
|
||||||
if (handler_ != nullptr) {
|
if (handler_ != nullptr) {
|
||||||
return RET_ERROR;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
std::string real_path = RealPath(lib_path.c_str());
|
std::string real_path = RealPath(lib_path.c_str());
|
||||||
|
if (real_path.empty()) {
|
||||||
|
MS_LOG(ERROR) << "real_path is invalid.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
#ifndef _WIN32
|
#ifndef _WIN32
|
||||||
#ifndef ENABLE_ARM
|
#ifndef ENABLE_ARM
|
||||||
handler_ = dlopen(real_path.c_str(), RTLD_LAZY | RTLD_DEEPBIND);
|
handler_ = dlopen(real_path.c_str(), RTLD_LAZY | RTLD_DEEPBIND);
|
||||||
|
|
|
@ -0,0 +1,99 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "src/common/storage.h"
|
||||||
|
#include <sys/stat.h>
|
||||||
|
#ifndef _MSC_VER
|
||||||
|
#include <unistd.h>
|
||||||
|
#endif
|
||||||
|
#include "flatbuffers/flatbuffers.h"
|
||||||
|
#include "src/common/log_adapter.h"
|
||||||
|
#include "src/common/file_utils.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
namespace {
|
||||||
|
constexpr size_t kMaxNum1024 = 1024;
|
||||||
|
}
|
||||||
|
int Storage::Save(const schema::MetaGraphT &graph, const std::string &outputPath) {
|
||||||
|
flatbuffers::FlatBufferBuilder builder(kMaxNum1024);
|
||||||
|
auto offset = schema::MetaGraph::Pack(builder, &graph);
|
||||||
|
builder.Finish(offset);
|
||||||
|
schema::FinishMetaGraphBuffer(builder, offset);
|
||||||
|
int size = builder.GetSize();
|
||||||
|
auto content = builder.GetBufferPointer();
|
||||||
|
if (content == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "GetBufferPointer nullptr";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
std::string filename = outputPath;
|
||||||
|
if (filename.length() == 0) {
|
||||||
|
MS_LOG(ERROR) << "Invalid output path.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
if (filename.substr(filename.find_last_of(".") + 1) != "ms") {
|
||||||
|
filename = filename + ".ms";
|
||||||
|
}
|
||||||
|
#ifndef _MSC_VER
|
||||||
|
if (access(filename.c_str(), F_OK) == 0) {
|
||||||
|
chmod(filename.c_str(), S_IWUSR);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
std::ofstream output(filename, std::ofstream::binary);
|
||||||
|
if (!output.is_open()) {
|
||||||
|
MS_LOG(ERROR) << "Can not open output file: " << filename;
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
output.write((const char *)content, size);
|
||||||
|
if (output.bad()) {
|
||||||
|
output.close();
|
||||||
|
MS_LOG(ERROR) << "Write output file : " << filename << " failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
output.close();
|
||||||
|
#ifndef _MSC_VER
|
||||||
|
chmod(filename.c_str(), S_IRUSR);
|
||||||
|
#endif
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
schema::MetaGraphT *Storage::Load(const std::string &inputPath) {
|
||||||
|
size_t size = 0;
|
||||||
|
std::string filename = inputPath;
|
||||||
|
if (filename.length() == 0) {
|
||||||
|
MS_LOG(ERROR) << "Invalid input path.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (filename.substr(filename.find_last_of(".") + 1) != "ms") {
|
||||||
|
filename = filename + ".ms";
|
||||||
|
}
|
||||||
|
auto buf = ReadFile(filename.c_str(), &size);
|
||||||
|
if (buf == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "the file buffer is nullptr";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
flatbuffers::Verifier verify((const uint8_t *)buf, size);
|
||||||
|
if (!schema::VerifyMetaGraphBuffer(verify)) {
|
||||||
|
MS_LOG(ERROR) << "the buffer is invalid and fail to create meta graph";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto graphDefT = schema::UnPackMetaGraph(buf);
|
||||||
|
return graphDefT.release();
|
||||||
|
}
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,37 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_LITE_SRC_COMMON_STORAGE_H
|
||||||
|
#define MINDSPORE_LITE_SRC_COMMON_STORAGE_H
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
#include <string>
|
||||||
|
#include "include/errorcode.h"
|
||||||
|
#include "flatbuffers/flatbuffers.h"
|
||||||
|
#include "schema/inner/model_generated.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
class Storage {
|
||||||
|
public:
|
||||||
|
static int Save(const schema::MetaGraphT &graph, const std::string &outputPath);
|
||||||
|
|
||||||
|
static schema::MetaGraphT *Load(const std::string &inputPath);
|
||||||
|
};
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_LITE_TOOLS_COMMON_STORAGE_H
|
|
@ -30,16 +30,74 @@
|
||||||
#include "src/cxx_api/callback/callback_adapter.h"
|
#include "src/cxx_api/callback/callback_adapter.h"
|
||||||
#include "src/cxx_api/callback/callback_impl.h"
|
#include "src/cxx_api/callback/callback_impl.h"
|
||||||
#include "src/cxx_api/model/model_impl.h"
|
#include "src/cxx_api/model/model_impl.h"
|
||||||
|
#ifdef ENABLE_OPENSSL
|
||||||
|
#include "src/common/decrypt.h"
|
||||||
|
#include "src/common/file_utils.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
std::mutex g_impl_init_lock;
|
std::mutex g_impl_init_lock;
|
||||||
|
#ifdef ENABLE_OPENSSL
|
||||||
|
Status DecryptModel(const std::string &cropto_lib_path, const void *model_buf, size_t model_size, const Key &dec_key,
|
||||||
|
const std::string &dec_mode, std::unique_ptr<Byte[]> *decrypt_buffer, size_t *decrypt_len) {
|
||||||
|
if (model_buf == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "model_buf is nullptr.";
|
||||||
|
return kLiteError;
|
||||||
|
}
|
||||||
|
*decrypt_len = 0;
|
||||||
|
*decrypt_buffer = lite::Decrypt(cropto_lib_path, decrypt_len, reinterpret_cast<const Byte *>(model_buf), model_size,
|
||||||
|
dec_key.key, dec_key.len, dec_mode);
|
||||||
|
if (*decrypt_buffer == nullptr || *decrypt_len == 0) {
|
||||||
|
MS_LOG(ERROR) << "Decrypt buffer failed";
|
||||||
|
return kLiteError;
|
||||||
|
}
|
||||||
|
return kSuccess;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
|
Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
|
||||||
const std::shared_ptr<Context> &model_context, const Key &dec_key,
|
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode,
|
||||||
const std::vector<char> &dec_mode) {
|
const std::string &cropto_lib_path) {
|
||||||
|
#ifdef ENABLE_OPENSSL
|
||||||
if (impl_ == nullptr) {
|
if (impl_ == nullptr) {
|
||||||
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
|
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
|
||||||
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
|
impl_ = std::make_shared<ModelImpl>();
|
||||||
|
if (impl_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Model implement is null.";
|
||||||
|
return kLiteFileError;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (dec_key.len > 0) {
|
||||||
|
std::unique_ptr<Byte[]> decrypt_buffer;
|
||||||
|
size_t decrypt_len = 0;
|
||||||
|
Status ret = DecryptModel(cropto_lib_path, model_data, data_size, dec_key, dec_mode, &decrypt_buffer, &decrypt_len);
|
||||||
|
if (ret != kSuccess) {
|
||||||
|
MS_LOG(ERROR) << "Decrypt model failed.";
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
ret = impl_->Build(decrypt_buffer.get(), decrypt_len, model_type, model_context);
|
||||||
|
if (ret != kSuccess) {
|
||||||
|
MS_LOG(ERROR) << "Build model failed.";
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Status ret = impl_->Build(model_data, data_size, model_type, model_context);
|
||||||
|
if (ret != kSuccess) {
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return kSuccess;
|
||||||
|
#else
|
||||||
|
MS_LOG(ERROR) << "The lib is not support Decrypt Model.";
|
||||||
|
return kLiteError;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
|
||||||
|
const std::shared_ptr<Context> &model_context) {
|
||||||
|
if (impl_ == nullptr) {
|
||||||
|
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
|
||||||
|
impl_ = std::make_shared<ModelImpl>();
|
||||||
if (impl_ == nullptr) {
|
if (impl_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Model implement is null.";
|
MS_LOG(ERROR) << "Model implement is null.";
|
||||||
return kLiteFileError;
|
return kLiteFileError;
|
||||||
|
@ -54,11 +112,59 @@ Status Model::Build(const void *model_data, size_t data_size, ModelType model_ty
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Model::Build(const std::vector<char> &model_path, ModelType model_type,
|
Status Model::Build(const std::vector<char> &model_path, ModelType model_type,
|
||||||
const std::shared_ptr<Context> &model_context, const Key &dec_key,
|
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode,
|
||||||
const std::vector<char> &dec_mode) {
|
const std::vector<char> &cropto_lib_path) {
|
||||||
|
#ifdef ENABLE_OPENSSL
|
||||||
if (impl_ == nullptr) {
|
if (impl_ == nullptr) {
|
||||||
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
|
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
|
||||||
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
|
impl_ = std::make_shared<ModelImpl>();
|
||||||
|
if (impl_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Model implement is null.";
|
||||||
|
return kLiteFileError;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (dec_key.len > 0) {
|
||||||
|
size_t model_size;
|
||||||
|
auto model_buf = lite::ReadFile(model_path.data(), &model_size);
|
||||||
|
if (model_buf == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Read model file failed";
|
||||||
|
return kLiteError;
|
||||||
|
}
|
||||||
|
std::unique_ptr<Byte[]> decrypt_buffer;
|
||||||
|
size_t decrypt_len = 0;
|
||||||
|
Status ret = DecryptModel(CharToString(cropto_lib_path), model_buf, model_size, dec_key, dec_mode, &decrypt_buffer,
|
||||||
|
&decrypt_len);
|
||||||
|
if (ret != kSuccess) {
|
||||||
|
MS_LOG(ERROR) << "Decrypt model failed.";
|
||||||
|
delete[] model_buf;
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
ret = impl_->Build(decrypt_buffer.get(), decrypt_len, model_type, model_context);
|
||||||
|
if (ret != kSuccess) {
|
||||||
|
MS_LOG(ERROR) << "Build model failed.";
|
||||||
|
delete[] model_buf;
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
delete[] model_buf;
|
||||||
|
} else {
|
||||||
|
Status ret = impl_->Build(CharToString(model_path), model_type, model_context);
|
||||||
|
if (ret != kSuccess) {
|
||||||
|
MS_LOG(ERROR) << "Build model failed.";
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return kSuccess;
|
||||||
|
#else
|
||||||
|
MS_LOG(ERROR) << "The lib is not support Decrypt Model.";
|
||||||
|
return kLiteError;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Model::Build(const std::vector<char> &model_path, ModelType model_type,
|
||||||
|
const std::shared_ptr<Context> &model_context) {
|
||||||
|
if (impl_ == nullptr) {
|
||||||
|
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
|
||||||
|
impl_ = std::make_shared<ModelImpl>();
|
||||||
if (impl_ == nullptr) {
|
if (impl_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Model implement is null.";
|
MS_LOG(ERROR) << "Model implement is null.";
|
||||||
return kLiteFileError;
|
return kLiteFileError;
|
||||||
|
@ -77,7 +183,7 @@ Status Model::Build(GraphCell graph, const std::shared_ptr<Context> &model_conte
|
||||||
std::stringstream err_msg;
|
std::stringstream err_msg;
|
||||||
if (impl_ == nullptr) {
|
if (impl_ == nullptr) {
|
||||||
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
|
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
|
||||||
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
|
impl_ = std::make_shared<ModelImpl>();
|
||||||
if (impl_ == nullptr) {
|
if (impl_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Model implement is null.";
|
MS_LOG(ERROR) << "Model implement is null.";
|
||||||
return kLiteFileError;
|
return kLiteFileError;
|
||||||
|
@ -258,7 +364,7 @@ Status Model::LoadConfig(const std::vector<char> &config_path) {
|
||||||
return Status(kLiteFileError, "Illegal operation.");
|
return Status(kLiteFileError, "Illegal operation.");
|
||||||
}
|
}
|
||||||
|
|
||||||
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
|
impl_ = std::make_shared<ModelImpl>();
|
||||||
if (impl_ == nullptr) {
|
if (impl_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Model implement is null.";
|
MS_LOG(ERROR) << "Model implement is null.";
|
||||||
return Status(kLiteFileError, "Fail to load config file.");
|
return Status(kLiteFileError, "Fail to load config file.");
|
||||||
|
@ -276,7 +382,7 @@ Status Model::UpdateConfig(const std::vector<char> §ion,
|
||||||
const std::pair<std::vector<char>, std::vector<char>> &config) {
|
const std::pair<std::vector<char>, std::vector<char>> &config) {
|
||||||
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
|
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
|
||||||
if (impl_ == nullptr) {
|
if (impl_ == nullptr) {
|
||||||
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
|
impl_ = std::make_shared<ModelImpl>();
|
||||||
}
|
}
|
||||||
if (impl_ != nullptr) {
|
if (impl_ != nullptr) {
|
||||||
return impl_->UpdateConfig(CharToString(section), {CharToString(config.first), CharToString(config.second)});
|
return impl_->UpdateConfig(CharToString(section), {CharToString(config.first), CharToString(config.second)});
|
||||||
|
@ -388,5 +494,4 @@ float Model::GetLearningRate() {
|
||||||
}
|
}
|
||||||
return impl_->GetLearningRate();
|
return impl_->GetLearningRate();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -26,7 +26,7 @@
|
||||||
#include "schema/inner/model_generated.h"
|
#include "schema/inner/model_generated.h"
|
||||||
#include "src/train/train_utils.h"
|
#include "src/train/train_utils.h"
|
||||||
#include "src/common/quant_utils.h"
|
#include "src/common/quant_utils.h"
|
||||||
#include "tools/common/meta_graph_serializer.h"
|
#include "src/common/storage.h"
|
||||||
#include "src/train/graph_fusion.h"
|
#include "src/train/graph_fusion.h"
|
||||||
#include "src/train/graph_dropout.h"
|
#include "src/train/graph_dropout.h"
|
||||||
#include "src/weight_decoder.h"
|
#include "src/weight_decoder.h"
|
||||||
|
@ -553,7 +553,7 @@ int TrainExport::ExportInit(const std::string model_name, std::string version) {
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
int TrainExport::SaveToFile() { return MetaGraphSerializer::Save(*meta_graph_, file_name_); }
|
int TrainExport::SaveToFile() { return Storage::Save(*meta_graph_, file_name_); }
|
||||||
|
|
||||||
bool TrainExport::IsInputTensor(const schema::TensorT &t) {
|
bool TrainExport::IsInputTensor(const schema::TensorT &t) {
|
||||||
int total_dims = std::accumulate(t.dims.begin(), t.dims.end(), 1, std::multiplies<int>());
|
int total_dims = std::accumulate(t.dims.begin(), t.dims.end(), 1, std::multiplies<int>());
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
844020
|
848116
|
||||||
|
|
|
@ -69,6 +69,7 @@ constexpr int kNumPrintMin = 5;
|
||||||
constexpr const char *DELIM_COLON = ":";
|
constexpr const char *DELIM_COLON = ":";
|
||||||
constexpr const char *DELIM_COMMA = ",";
|
constexpr const char *DELIM_COMMA = ",";
|
||||||
constexpr const char *DELIM_SLASH = "/";
|
constexpr const char *DELIM_SLASH = "/";
|
||||||
|
constexpr size_t kEncMaxLen = 16;
|
||||||
|
|
||||||
extern const std::unordered_map<int, std::string> kTypeIdMap;
|
extern const std::unordered_map<int, std::string> kTypeIdMap;
|
||||||
extern const std::unordered_map<mindspore::Format, std::string> kTensorFormatMap;
|
extern const std::unordered_map<mindspore::Format, std::string> kTensorFormatMap;
|
||||||
|
@ -139,6 +140,11 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
|
||||||
AddFlag(&BenchmarkFlags::cosine_distance_threshold_, "cosineDistanceThreshold", "cosine distance threshold", -1.1);
|
AddFlag(&BenchmarkFlags::cosine_distance_threshold_, "cosineDistanceThreshold", "cosine distance threshold", -1.1);
|
||||||
AddFlag(&BenchmarkFlags::resize_dims_in_, "inputShapes",
|
AddFlag(&BenchmarkFlags::resize_dims_in_, "inputShapes",
|
||||||
"Shape of input data, the format should be NHWC. e.g. 1,32,32,32:1,1,32,32,1", "");
|
"Shape of input data, the format should be NHWC. e.g. 1,32,32,32:1,1,32,32,1", "");
|
||||||
|
AddFlag(&BenchmarkFlags::decrypt_key_str_, "decryptKey",
|
||||||
|
"The key used to decrypt the file, expressed in hexadecimal characters. Only support AES-GCM and the key "
|
||||||
|
"length is 16.",
|
||||||
|
"");
|
||||||
|
AddFlag(&BenchmarkFlags::crypto_lib_path_, "cryptoLibPath", "Pass the crypto library path.", "");
|
||||||
AddFlag(&BenchmarkFlags::enable_parallel_predict_, "enableParallelPredict", "Enable model parallel : true | false",
|
AddFlag(&BenchmarkFlags::enable_parallel_predict_, "enableParallelPredict", "Enable model parallel : true | false",
|
||||||
false);
|
false);
|
||||||
AddFlag(&BenchmarkFlags::parallel_request_num_, "parallelRequestNum", "parallel request num of parallel predict",
|
AddFlag(&BenchmarkFlags::parallel_request_num_, "parallelRequestNum", "parallel request num of parallel predict",
|
||||||
|
@ -192,6 +198,9 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
|
||||||
std::string perf_event_ = "CYCLE";
|
std::string perf_event_ = "CYCLE";
|
||||||
bool dump_tensor_data_ = false;
|
bool dump_tensor_data_ = false;
|
||||||
bool print_tensor_data_ = false;
|
bool print_tensor_data_ = false;
|
||||||
|
std::string decrypt_key_str_;
|
||||||
|
std::string dec_mode_ = "AES-GCM";
|
||||||
|
std::string crypto_lib_path_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class MS_API BenchmarkBase {
|
class MS_API BenchmarkBase {
|
||||||
|
|
|
@ -698,7 +698,6 @@ int BenchmarkUnifiedApi::CompareDataGetTotalBiasAndSize(const std::string &name,
|
||||||
*total_size += 1;
|
*total_size += 1;
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
int BenchmarkUnifiedApi::CompareDataGetTotalCosineDistanceAndSize(const std::string &name, mindspore::MSTensor *tensor,
|
int BenchmarkUnifiedApi::CompareDataGetTotalCosineDistanceAndSize(const std::string &name, mindspore::MSTensor *tensor,
|
||||||
float *total_cosine_distance, int *total_size) {
|
float *total_cosine_distance, int *total_size) {
|
||||||
if (tensor == nullptr) {
|
if (tensor == nullptr) {
|
||||||
|
@ -1044,6 +1043,33 @@ int BenchmarkUnifiedApi::RunModelPool(std::shared_ptr<mindspore::Context> contex
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
int BenchmarkUnifiedApi::CompileGraph(ModelType model_type, const std::shared_ptr<Context> &context,
|
||||||
|
const std::string &model_name) {
|
||||||
|
Key dec_key;
|
||||||
|
if (!flags_->decrypt_key_str_.empty()) {
|
||||||
|
dec_key.len = lite::Hex2ByteArray(flags_->decrypt_key_str_, dec_key.key, kEncMaxLen);
|
||||||
|
if (dec_key.len == 0) {
|
||||||
|
MS_LOG(ERROR) << "dec_key.len == 0";
|
||||||
|
return RET_INPUT_PARAM_INVALID;
|
||||||
|
}
|
||||||
|
flags_->decrypt_key_str_.clear();
|
||||||
|
}
|
||||||
|
Status ret;
|
||||||
|
if (flags_->crypto_lib_path_.empty()) {
|
||||||
|
ret = ms_model_.Build(flags_->model_file_, model_type, context);
|
||||||
|
} else {
|
||||||
|
ret =
|
||||||
|
ms_model_.Build(flags_->model_file_, model_type, context, dec_key, flags_->dec_mode_, flags_->crypto_lib_path_);
|
||||||
|
}
|
||||||
|
memset(dec_key.key, 0, kEncMaxLen);
|
||||||
|
if (ret != kSuccess) {
|
||||||
|
MS_LOG(ERROR) << "ms_model_.Build failed while running ", model_name.c_str();
|
||||||
|
std::cout << "ms_model_.Build failed while running ", model_name.c_str();
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int BenchmarkUnifiedApi::RunBenchmark() {
|
int BenchmarkUnifiedApi::RunBenchmark() {
|
||||||
auto start_prepare_time = GetTimeUs();
|
auto start_prepare_time = GetTimeUs();
|
||||||
|
|
||||||
|
@ -1098,19 +1124,17 @@ int BenchmarkUnifiedApi::RunBenchmark() {
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
auto ret = ms_model_.Build(flags_->model_file_, model_type, context);
|
status = CompileGraph(model_type, context, model_name);
|
||||||
if (ret != kSuccess) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "ms_model_.Build failed while running ", model_name.c_str();
|
MS_LOG(ERROR) << "Compile graph failed.";
|
||||||
std::cout << "ms_model_.Build failed while running ", model_name.c_str();
|
return status;
|
||||||
return RET_ERROR;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!flags_->resize_dims_.empty()) {
|
if (!flags_->resize_dims_.empty()) {
|
||||||
std::vector<std::vector<int64_t>> resize_dims;
|
std::vector<std::vector<int64_t>> resize_dims;
|
||||||
(void)std::transform(flags_->resize_dims_.begin(), flags_->resize_dims_.end(), std::back_inserter(resize_dims),
|
(void)std::transform(flags_->resize_dims_.begin(), flags_->resize_dims_.end(), std::back_inserter(resize_dims),
|
||||||
[&](auto &shapes) { return this->ConverterToInt64Vector<int>(shapes); });
|
[&](auto &shapes) { return this->ConverterToInt64Vector<int>(shapes); });
|
||||||
|
|
||||||
ret = ms_model_.Resize(ms_model_.GetInputs(), resize_dims);
|
auto ret = ms_model_.Resize(ms_model_.GetInputs(), resize_dims);
|
||||||
if (ret != kSuccess) {
|
if (ret != kSuccess) {
|
||||||
MS_LOG(ERROR) << "Input tensor resize failed.";
|
MS_LOG(ERROR) << "Input tensor resize failed.";
|
||||||
std::cout << "Input tensor resize failed.";
|
std::cout << "Input tensor resize failed.";
|
||||||
|
|
|
@ -62,6 +62,8 @@ class MS_API BenchmarkUnifiedApi : public BenchmarkBase {
|
||||||
float *total_cosine_distance, int *total_size);
|
float *total_cosine_distance, int *total_size);
|
||||||
void InitContext(const std::shared_ptr<mindspore::Context> &context);
|
void InitContext(const std::shared_ptr<mindspore::Context> &context);
|
||||||
|
|
||||||
|
int CompileGraph(ModelType model_type, const std::shared_ptr<Context> &context, const std::string &model_name);
|
||||||
|
|
||||||
#ifdef ENABLE_OPENGL_TEXTURE
|
#ifdef ENABLE_OPENGL_TEXTURE
|
||||||
int GenerateGLTexture(std::map<std::string, GLuint> *inputGlTexture);
|
int GenerateGLTexture(std::map<std::string, GLuint> *inputGlTexture);
|
||||||
|
|
||||||
|
|
|
@ -206,7 +206,8 @@ bool MetaGraphSerializer::ExtraAndSerializeModelWeight(const schema::MetaGraphT
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool MetaGraphSerializer::SerializeModelAndUpdateWeight(const schema::MetaGraphT &meta_graphT) {
|
bool MetaGraphSerializer::SerializeModelAndUpdateWeight(const schema::MetaGraphT &meta_graphT, const Byte *key,
|
||||||
|
const size_t key_len, const std::string &enc_mode) {
|
||||||
// serialize model
|
// serialize model
|
||||||
flatbuffers::FlatBufferBuilder builder(kFlatbuffersBuilderInitSize);
|
flatbuffers::FlatBufferBuilder builder(kFlatbuffersBuilderInitSize);
|
||||||
auto offset = schema::MetaGraph::Pack(builder, &meta_graphT);
|
auto offset = schema::MetaGraph::Pack(builder, &meta_graphT);
|
||||||
|
@ -214,7 +215,7 @@ bool MetaGraphSerializer::SerializeModelAndUpdateWeight(const schema::MetaGraphT
|
||||||
schema::FinishMetaGraphBuffer(builder, offset);
|
schema::FinishMetaGraphBuffer(builder, offset);
|
||||||
size_t size = builder.GetSize();
|
size_t size = builder.GetSize();
|
||||||
auto content = builder.GetBufferPointer();
|
auto content = builder.GetBufferPointer();
|
||||||
if (!SerializeModel(content, size)) {
|
if (!SerializeModel(content, size, key, key_len, enc_mode)) {
|
||||||
MS_LOG(ERROR) << "Serialize graph failed";
|
MS_LOG(ERROR) << "Serialize graph failed";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -238,7 +239,8 @@ bool MetaGraphSerializer::SerializeModelAndUpdateWeight(const schema::MetaGraphT
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
int MetaGraphSerializer::Save(const schema::MetaGraphT &graph, const std::string &output_path) {
|
int MetaGraphSerializer::Save(const schema::MetaGraphT &graph, const std::string &output_path, const Byte *key,
|
||||||
|
const size_t key_len, const std::string &enc_mode) {
|
||||||
flatbuffers::FlatBufferBuilder builder(kFlatbuffersBuilderInitSize);
|
flatbuffers::FlatBufferBuilder builder(kFlatbuffersBuilderInitSize);
|
||||||
auto offset = schema::MetaGraph::Pack(builder, &graph);
|
auto offset = schema::MetaGraph::Pack(builder, &graph);
|
||||||
builder.Finish(offset);
|
builder.Finish(offset);
|
||||||
|
@ -255,7 +257,7 @@ int MetaGraphSerializer::Save(const schema::MetaGraphT &graph, const std::string
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
if (save_together) {
|
if (save_together) {
|
||||||
if (!meta_graph_serializer.SerializeModel(builder.GetBufferPointer(), size)) {
|
if (!meta_graph_serializer.SerializeModel(builder.GetBufferPointer(), size, key, key_len, enc_mode)) {
|
||||||
MS_LOG(ERROR) << "Serialize graph failed";
|
MS_LOG(ERROR) << "Serialize graph failed";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
@ -264,7 +266,7 @@ int MetaGraphSerializer::Save(const schema::MetaGraphT &graph, const std::string
|
||||||
MS_LOG(ERROR) << "Serialize graph weight failed";
|
MS_LOG(ERROR) << "Serialize graph weight failed";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
if (!meta_graph_serializer.SerializeModelAndUpdateWeight(graph)) {
|
if (!meta_graph_serializer.SerializeModelAndUpdateWeight(graph, key, key_len, enc_mode)) {
|
||||||
MS_LOG(ERROR) << "Serialize graph and adjust weight failed";
|
MS_LOG(ERROR) << "Serialize graph and adjust weight failed";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
@ -283,14 +285,25 @@ MetaGraphSerializer::~MetaGraphSerializer() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool MetaGraphSerializer::SerializeModel(const void *content, size_t size) {
|
bool MetaGraphSerializer::SerializeModel(const void *content, size_t size, const Byte *key, const size_t key_len,
|
||||||
|
const std::string &enc_mode) {
|
||||||
MS_ASSERT(model_fs_ != nullptr);
|
MS_ASSERT(model_fs_ != nullptr);
|
||||||
if (size == 0 || content == nullptr) {
|
if (size == 0 || content == nullptr) {
|
||||||
MS_LOG(ERROR) << "Input meta graph buffer is nullptr";
|
MS_LOG(ERROR) << "Input meta graph buffer is nullptr";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
if (key_len > 0) {
|
||||||
model_fs_->write((const char *)content, static_cast<int64_t>(size));
|
size_t encrypt_len;
|
||||||
|
auto encrypt_content = Encrypt(&encrypt_len, reinterpret_cast<const Byte *>(content), size, key, key_len, enc_mode);
|
||||||
|
if (encrypt_content == nullptr || encrypt_len == 0) {
|
||||||
|
MS_LOG(ERROR) << "Encrypt failed.";
|
||||||
|
model_fs_->close();
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
model_fs_->write(reinterpret_cast<const char *>(encrypt_content.get()), encrypt_len);
|
||||||
|
} else {
|
||||||
|
model_fs_->write((const char *)content, static_cast<int64_t>(size));
|
||||||
|
}
|
||||||
if (model_fs_->bad()) {
|
if (model_fs_->bad()) {
|
||||||
MS_LOG(ERROR) << "Write model file failed: " << save_model_path_;
|
MS_LOG(ERROR) << "Write model file failed: " << save_model_path_;
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
|
|
|
@ -21,12 +21,14 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "flatbuffers/flatbuffers.h"
|
#include "flatbuffers/flatbuffers.h"
|
||||||
#include "schema/inner/model_generated.h"
|
#include "schema/inner/model_generated.h"
|
||||||
|
#include "utils/crypto.h"
|
||||||
|
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
class MetaGraphSerializer {
|
class MetaGraphSerializer {
|
||||||
public:
|
public:
|
||||||
// save serialized fb model
|
// save serialized fb model
|
||||||
static int Save(const schema::MetaGraphT &graph, const std::string &output_path);
|
static int Save(const schema::MetaGraphT &graph, const std::string &output_path, const Byte *key = {},
|
||||||
|
const size_t key_len = 0, const std::string &enc_mode = "");
|
||||||
|
|
||||||
private:
|
private:
|
||||||
MetaGraphSerializer() = default;
|
MetaGraphSerializer() = default;
|
||||||
|
@ -41,9 +43,11 @@ class MetaGraphSerializer {
|
||||||
|
|
||||||
bool ExtraAndSerializeModelWeight(const schema::MetaGraphT &graph);
|
bool ExtraAndSerializeModelWeight(const schema::MetaGraphT &graph);
|
||||||
|
|
||||||
bool SerializeModelAndUpdateWeight(const schema::MetaGraphT &meta_graphT);
|
bool SerializeModelAndUpdateWeight(const schema::MetaGraphT &meta_graphT, const Byte *key, const size_t key_len,
|
||||||
|
const std::string &enc_mode);
|
||||||
|
|
||||||
bool SerializeModel(const void *content, size_t size);
|
bool SerializeModel(const void *content, size_t size, const Byte *key, const size_t key_len,
|
||||||
|
const std::string &enc_mode);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int64_t cur_offset_ = 0;
|
int64_t cur_offset_ = 0;
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <regex>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
|
@ -126,5 +127,42 @@ bool ConvertDoubleVector(const std::string &str, std::vector<double> *value) {
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t Hex2ByteArray(const std::string &hex_str, unsigned char *byte_array, size_t max_len) {
|
||||||
|
std::regex r("[0-9a-fA-F]+");
|
||||||
|
if (!std::regex_match(hex_str, r)) {
|
||||||
|
MS_LOG(ERROR) << "Some characters of dec_key not in [0-9a-fA-F]";
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
if (hex_str.size() % 2 == 1) { // Mod 2 determines whether it is odd
|
||||||
|
MS_LOG(ERROR) << "the hexadecimal dec_key length must be even";
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
size_t byte_len = hex_str.size() / 2; // Two hexadecimal characters represent a byte
|
||||||
|
if (byte_len > max_len) {
|
||||||
|
MS_LOG(ERROR) << "the hexadecimal dec_key length exceeds the maximum limit: " << max_len;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
constexpr int32_t a_val = 10; // The value of 'A' in hexadecimal is 10
|
||||||
|
constexpr size_t half_byte_offset = 4;
|
||||||
|
for (size_t i = 0; i < byte_len; ++i) {
|
||||||
|
size_t p = i * 2; // The i-th byte is represented by the 2*i and 2*i+1 hexadecimal characters
|
||||||
|
if (hex_str[p] >= 'a' && hex_str[p] <= 'f') {
|
||||||
|
byte_array[i] = hex_str[p] - 'a' + a_val;
|
||||||
|
} else if (hex_str[p] >= 'A' && hex_str[p] <= 'F') {
|
||||||
|
byte_array[i] = hex_str[p] - 'A' + a_val;
|
||||||
|
} else {
|
||||||
|
byte_array[i] = hex_str[p] - '0';
|
||||||
|
}
|
||||||
|
if (hex_str[p + 1] >= 'a' && hex_str[p + 1] <= 'f') {
|
||||||
|
byte_array[i] = (byte_array[i] << half_byte_offset) | (hex_str[p + 1] - 'a' + a_val);
|
||||||
|
} else if (hex_str[p] >= 'A' && hex_str[p] <= 'F') {
|
||||||
|
byte_array[i] = (byte_array[i] << half_byte_offset) | (hex_str[p + 1] - 'A' + a_val);
|
||||||
|
} else {
|
||||||
|
byte_array[i] = (byte_array[i] << half_byte_offset) | (hex_str[p + 1] - '0');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return byte_len;
|
||||||
|
}
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -40,6 +40,8 @@ bool ConvertDoubleNum(const std::string &str, double *value);
|
||||||
bool ConvertBool(std::string str, bool *value);
|
bool ConvertBool(std::string str, bool *value);
|
||||||
|
|
||||||
bool ConvertDoubleVector(const std::string &str, std::vector<double> *value);
|
bool ConvertDoubleVector(const std::string &str, std::vector<double> *value);
|
||||||
|
|
||||||
|
size_t Hex2ByteArray(const std::string &hex_str, unsigned char *byte_array, size_t max_len);
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_LITE_SRC_TOOLS_STRING_UTIL_H_
|
#endif // MINDSPORE_LITE_SRC_TOOLS_STRING_UTIL_H_
|
||||||
|
|
|
@ -343,4 +343,32 @@ int GenerateRandomData(mindspore::tensor::MSTensor *tensor) {
|
||||||
}
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int GenerateRandomData(mindspore::MSTensor *tensor) {
|
||||||
|
MS_ASSERT(tensor != nullptr);
|
||||||
|
auto input_data = tensor->MutableData();
|
||||||
|
if (input_data == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "MallocData for inTensor failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
int status = RET_ERROR;
|
||||||
|
if (static_cast<TypeId>(tensor->DataType()) == kObjectTypeString) {
|
||||||
|
MSTensor *input = MSTensor::StringsToTensor(tensor->Name(), {"you're the best."});
|
||||||
|
if (input == nullptr) {
|
||||||
|
std::cerr << "StringsToTensor failed" << std::endl;
|
||||||
|
MS_LOG(ERROR) << "StringsToTensor failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
*tensor = *input;
|
||||||
|
delete input;
|
||||||
|
} else {
|
||||||
|
status = GenerateRandomData(tensor->DataSize(), input_data, static_cast<int>(tensor->DataType()));
|
||||||
|
}
|
||||||
|
if (status != RET_OK) {
|
||||||
|
std::cerr << "GenerateRandomData for inTensor failed: " << status << std::endl;
|
||||||
|
MS_LOG(ERROR) << "GenerateRandomData for inTensor failed:" << status;
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
||||||
|
|
|
@ -78,6 +78,8 @@ std::unique_ptr<schema::QuantParamT> CopyQuantParamT(const std::unique_ptr<schem
|
||||||
|
|
||||||
int GenerateRandomData(mindspore::tensor::MSTensor *tensors);
|
int GenerateRandomData(mindspore::tensor::MSTensor *tensors);
|
||||||
|
|
||||||
|
int GenerateRandomData(mindspore::MSTensor *tensors);
|
||||||
|
|
||||||
int GenerateRandomData(size_t size, void *data, int data_type);
|
int GenerateRandomData(size_t size, void *data, int data_type);
|
||||||
|
|
||||||
template <typename T, typename Distribution>
|
template <typename T, typename Distribution>
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
add_definitions(-DPRIMITIVE_WRITEABLE)
|
add_definitions(-DPRIMITIVE_WRITEABLE)
|
||||||
add_definitions(-DUSE_GLOG)
|
add_definitions(-DUSE_GLOG)
|
||||||
set(USE_GLOG on)
|
set(USE_GLOG on)
|
||||||
|
if(MSLITE_ENABLE_MODEL_ENCRYPTION)
|
||||||
|
add_compile_definitions(ENABLE_OPENSSL)
|
||||||
|
endif()
|
||||||
set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src)
|
set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src)
|
||||||
|
|
||||||
set(CCSRC_SRC
|
set(CCSRC_SRC
|
||||||
|
@ -13,8 +15,8 @@ set(CCSRC_SRC
|
||||||
include_directories(${TOP_DIR}/mindspore/ccsrc/plugin/device/cpu/kernel)
|
include_directories(${TOP_DIR}/mindspore/ccsrc/plugin/device/cpu/kernel)
|
||||||
|
|
||||||
if(NOT WIN32 AND NOT MSLITE_ENABLE_ACL)
|
if(NOT WIN32 AND NOT MSLITE_ENABLE_ACL)
|
||||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -rdynamic -fvisibility=hidden")
|
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -rdynamic -fvisibility=hidden")
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -rdynamic -fvisibility=hidden")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -rdynamic -fvisibility=hidden")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
|
@ -50,11 +52,10 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/import/mindir_adjust.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/import/mindir_adjust.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/import/mindir_control_flow_adjust.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/import/mindir_control_flow_adjust.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/adapter/acl/acl_pass.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/adapter/acl/acl_pass.cc
|
||||||
|
|
||||||
${SRC_DIR}/common/quant_utils.cc
|
${SRC_DIR}/common/quant_utils.cc
|
||||||
${SRC_DIR}/common/dynamic_library_loader.cc
|
${SRC_DIR}/common/dynamic_library_loader.cc
|
||||||
${SRC_DIR}/train/train_populate_parameter.cc
|
${SRC_DIR}/train/train_populate_parameter.cc
|
||||||
|
${SRC_DIR}/common/config_file.cc
|
||||||
../optimizer/*.cc
|
../optimizer/*.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -76,16 +77,20 @@ add_subdirectory(micro/coder)
|
||||||
|
|
||||||
if(MSLITE_ENABLE_ACL)
|
if(MSLITE_ENABLE_ACL)
|
||||||
set(MODE_ASCEND_ACL ON)
|
set(MODE_ASCEND_ACL ON)
|
||||||
|
include_directories(${TOP_DIR}/graphengine/inc/external)
|
||||||
include(${TOP_DIR}/cmake/dependency_graphengine.cmake)
|
include(${TOP_DIR}/cmake/dependency_graphengine.cmake)
|
||||||
add_subdirectory(adapter/acl)
|
add_subdirectory(adapter/acl)
|
||||||
link_directories(${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
link_directories(${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
set(API_SRC ${SRC_DIR}/cxx_api/context.cc)
|
file(GLOB CXX_API_SRCS
|
||||||
if(MSLITE_ENABLE_ACL)
|
${SRC_DIR}/cxx_api/*.cc
|
||||||
list(APPEND API_SRC ${SRC_DIR}/cxx_api/kernel.cc)
|
${SRC_DIR}/cxx_api/model/*.cc
|
||||||
endif()
|
${SRC_DIR}/cxx_api/graph/*.cc
|
||||||
|
${SRC_DIR}/cxx_api/tensor/*.cc)
|
||||||
|
|
||||||
set(LITE_SRC ${API_SRC}
|
set(LITE_SRC ${API_SRC}
|
||||||
|
${CXX_API_SRCS}
|
||||||
${SRC_DIR}/ops/ops_def.cc
|
${SRC_DIR}/ops/ops_def.cc
|
||||||
${SRC_DIR}/ops/ops_utils.cc
|
${SRC_DIR}/ops/ops_utils.cc
|
||||||
${SRC_DIR}/common/utils.cc
|
${SRC_DIR}/common/utils.cc
|
||||||
|
@ -97,6 +102,7 @@ set(LITE_SRC ${API_SRC}
|
||||||
${SRC_DIR}/common/log.cc
|
${SRC_DIR}/common/log.cc
|
||||||
${SRC_DIR}/common/prim_util.cc
|
${SRC_DIR}/common/prim_util.cc
|
||||||
${SRC_DIR}/common/tensor_util.cc
|
${SRC_DIR}/common/tensor_util.cc
|
||||||
|
${SRC_DIR}/common/decrypt.cc
|
||||||
${SRC_DIR}/runtime/allocator.cc
|
${SRC_DIR}/runtime/allocator.cc
|
||||||
${SRC_DIR}/runtime/inner_allocator.cc
|
${SRC_DIR}/runtime/inner_allocator.cc
|
||||||
${SRC_DIR}/runtime/runtime_allocator.cc
|
${SRC_DIR}/runtime/runtime_allocator.cc
|
||||||
|
|
|
@ -33,9 +33,15 @@
|
||||||
#include "tools/converter/import/mindspore_importer.h"
|
#include "tools/converter/import/mindspore_importer.h"
|
||||||
#include "nnacl/op_base.h"
|
#include "nnacl/op_base.h"
|
||||||
#include "tools/converter/micro/coder/coder.h"
|
#include "tools/converter/micro/coder/coder.h"
|
||||||
|
#include "src/common/prim_util.h"
|
||||||
|
#include "src/common/version_manager.h"
|
||||||
|
#include "tools/common/tensor_util.h"
|
||||||
|
#include "include/api/model.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
namespace {
|
namespace {
|
||||||
|
constexpr size_t kMaxNum1024 = 1024;
|
||||||
void InitConverterParameters(const converter::Flags &flag, converter::ConverterParameters *converter_parameters) {
|
void InitConverterParameters(const converter::Flags &flag, converter::ConverterParameters *converter_parameters) {
|
||||||
MS_ASSERT(converter_parameters != nullptr);
|
MS_ASSERT(converter_parameters != nullptr);
|
||||||
converter_parameters->fmk = flag.fmk;
|
converter_parameters->fmk = flag.fmk;
|
||||||
|
@ -178,6 +184,90 @@ schema::MetaGraphT *Converter::TransferFuncGraph(const std::unique_ptr<converter
|
||||||
return meta_graph;
|
return meta_graph;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int CheckExistCustomOps(const schema::MetaGraphT *meta_graph, bool *exist_custom_nodes) {
|
||||||
|
MS_CHECK_TRUE_MSG(meta_graph != nullptr && exist_custom_nodes != nullptr, RET_ERROR, "input params contain nullptr.");
|
||||||
|
flatbuffers::FlatBufferBuilder fbb(kMaxNum1024);
|
||||||
|
for (const auto &node : meta_graph->nodes) {
|
||||||
|
auto prim = ConvertToPrimitive(node->primitive.get(), &fbb);
|
||||||
|
if (prim == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "get primitive failed.";
|
||||||
|
fbb.Clear();
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
if (IsCustomNode(prim, static_cast<int>(SCHEMA_CUR))) {
|
||||||
|
*exist_custom_nodes = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fbb.Clear();
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int PreInference(const schema::MetaGraphT &meta_graph, const std::unique_ptr<converter::Flags> &flags) {
|
||||||
|
if (flags->trainModel) {
|
||||||
|
MS_LOG(WARNING) << "train model dont support pre-infer.";
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool exist_custom_nodes = false;
|
||||||
|
auto check_ret = CheckExistCustomOps(&meta_graph, &exist_custom_nodes);
|
||||||
|
if (check_ret == RET_ERROR) {
|
||||||
|
MS_LOG(ERROR) << "CheckExistCustomOps failed.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
if (exist_custom_nodes) {
|
||||||
|
MS_LOG(WARNING) << "exist custom nodes and will not be pre-infer.";
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
mindspore::Model model;
|
||||||
|
flatbuffers::FlatBufferBuilder builder(kMaxNum1024);
|
||||||
|
auto offset = schema::MetaGraph::Pack(builder, &meta_graph);
|
||||||
|
builder.Finish(offset);
|
||||||
|
schema::FinishMetaGraphBuffer(builder, offset);
|
||||||
|
int size = builder.GetSize();
|
||||||
|
auto content = builder.GetBufferPointer();
|
||||||
|
if (content == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "GetBufferPointer nullptr";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
auto context = std::make_shared<mindspore::Context>();
|
||||||
|
if (context == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "New context failed while running ";
|
||||||
|
std::cerr << "New context failed while running " << std::endl;
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>();
|
||||||
|
auto &device_list = context->MutableDeviceInfo();
|
||||||
|
device_list.push_back(device_info);
|
||||||
|
|
||||||
|
auto ret = model.Build(content, size, kMindIR, context);
|
||||||
|
if (ret != kSuccess) {
|
||||||
|
MS_LOG(ERROR) << "Build error ";
|
||||||
|
std::cerr << "Build error " << std::endl;
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
for (auto &tensor : model.GetInputs()) {
|
||||||
|
if (tensor.Shape().empty() || tensor.DataSize() <= 0 ||
|
||||||
|
std::find(tensor.Shape().begin(), tensor.Shape().end(), -1) != tensor.Shape().end()) {
|
||||||
|
MS_LOG(WARNING) << tensor.Name() << " is dynamic shape and will not be pre-infer.";
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
auto status = GenerateRandomData(&tensor);
|
||||||
|
if (status != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << tensor.Name() << "GenerateRandomData failed.";
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::vector<MSTensor> outputs;
|
||||||
|
ret = model.Predict(model.GetInputs(), &outputs);
|
||||||
|
if (ret != kSuccess) {
|
||||||
|
MS_LOG(ERROR) << "Inference error ";
|
||||||
|
std::cerr << "Inference error " << std::endl;
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int RunConverter(int argc, const char **argv) {
|
int RunConverter(int argc, const char **argv) {
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
auto flags = std::make_unique<converter::Flags>();
|
auto flags = std::make_unique<converter::Flags>();
|
||||||
|
@ -215,6 +305,18 @@ int RunConverter(int argc, const char **argv) {
|
||||||
// save graph to file
|
// save graph to file
|
||||||
meta_graph->version = Version();
|
meta_graph->version = Version();
|
||||||
|
|
||||||
|
if (flags->infer) {
|
||||||
|
status = PreInference(*meta_graph, flags);
|
||||||
|
if (status != RET_OK) {
|
||||||
|
oss.clear();
|
||||||
|
oss << "PRE INFERENCE FAILED:" << status << " " << GetErrorInfo(status);
|
||||||
|
MS_LOG(ERROR) << oss.str();
|
||||||
|
std::cout << oss.str() << std::endl;
|
||||||
|
delete meta_graph;
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (flags->microParam.enable_micro) {
|
if (flags->microParam.enable_micro) {
|
||||||
status = micro::Coder::MicroSourceCodeGeneration(*meta_graph, flags->outputFile, flags->microParam.codegen_mode,
|
status = micro::Coder::MicroSourceCodeGeneration(*meta_graph, flags->outputFile, flags->microParam.codegen_mode,
|
||||||
flags->microParam.target, flags->microParam.support_parallel,
|
flags->microParam.target, flags->microParam.support_parallel,
|
||||||
|
@ -228,7 +330,7 @@ int RunConverter(int argc, const char **argv) {
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
status = MetaGraphSerializer::Save(*meta_graph, flags->outputFile);
|
status = MetaGraphSerializer::Save(*meta_graph, flags->outputFile, flags->encKey, flags->keyLen, flags->encMode);
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
delete meta_graph;
|
delete meta_graph;
|
||||||
oss.clear();
|
oss.clear();
|
||||||
|
@ -238,7 +340,12 @@ int RunConverter(int argc, const char **argv) {
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// clear key
|
||||||
|
status = memset_s(flags->encKey, converter::kEncMaxLen, 0, converter::kEncMaxLen);
|
||||||
|
if (status != EOK) {
|
||||||
|
MS_LOG(ERROR) << "memset failed.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
delete meta_graph;
|
delete meta_graph;
|
||||||
oss.clear();
|
oss.clear();
|
||||||
oss << "CONVERT RESULT SUCCESS:" << status;
|
oss << "CONVERT RESULT SUCCESS:" << status;
|
||||||
|
|
|
@ -83,6 +83,20 @@ Flags::Flags() {
|
||||||
"");
|
"");
|
||||||
AddFlag(&Flags::graphInputFormatStr, "inputDataFormat",
|
AddFlag(&Flags::graphInputFormatStr, "inputDataFormat",
|
||||||
"Assign the input format of exported model. Only Valid for 4-dimensional input. NHWC | NCHW", "NHWC");
|
"Assign the input format of exported model. Only Valid for 4-dimensional input. NHWC | NCHW", "NHWC");
|
||||||
|
#ifdef ENABLE_OPENSSL
|
||||||
|
AddFlag(&Flags::encryptionStr, "encryption",
|
||||||
|
"Whether to export the encryption model."
|
||||||
|
"true | false",
|
||||||
|
"true");
|
||||||
|
AddFlag(&Flags::encKeyStr, "encryptKey",
|
||||||
|
"The key used to encrypt the file, expressed in hexadecimal characters. Only support AES-GCM and the key "
|
||||||
|
"length is 16.",
|
||||||
|
"");
|
||||||
|
#endif
|
||||||
|
AddFlag(&Flags::inferStr, "infer",
|
||||||
|
"Whether to do pre-inference after convert."
|
||||||
|
"true | false",
|
||||||
|
"false");
|
||||||
}
|
}
|
||||||
|
|
||||||
int Flags::InitInputOutputDataType() {
|
int Flags::InitInputOutputDataType() {
|
||||||
|
@ -310,8 +324,56 @@ int Flags::InitConfigFile() {
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
int Flags::Init(int argc, const char **argv) {
|
int Flags::InitSaveFP16() {
|
||||||
int ret;
|
if (saveFP16Str == "on") {
|
||||||
|
saveFP16 = true;
|
||||||
|
} else if (saveFP16Str == "off") {
|
||||||
|
saveFP16 = false;
|
||||||
|
} else {
|
||||||
|
std::cerr << "Init save_fp16 failed." << std::endl;
|
||||||
|
return RET_INPUT_PARAM_INVALID;
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int Flags::InitPreInference() {
|
||||||
|
if (this->inferStr == "true") {
|
||||||
|
this->infer = true;
|
||||||
|
} else if (this->inferStr == "false") {
|
||||||
|
this->infer = false;
|
||||||
|
} else {
|
||||||
|
std::cerr << "INPUT ILLEGAL: infer must be true|false " << std::endl;
|
||||||
|
return RET_INPUT_PARAM_INVALID;
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int Flags::InitEncrypt() {
|
||||||
|
if (this->encryptionStr == "true") {
|
||||||
|
this->encryption = true;
|
||||||
|
} else if (this->encryptionStr == "false") {
|
||||||
|
this->encryption = false;
|
||||||
|
} else {
|
||||||
|
std::cerr << "INPUT ILLEGAL: encryption must be true|false " << std::endl;
|
||||||
|
return RET_INPUT_PARAM_INVALID;
|
||||||
|
}
|
||||||
|
if (this->encryption) {
|
||||||
|
if (encKeyStr.empty()) {
|
||||||
|
MS_LOG(ERROR) << "If you don't need to use model encryption, please set --encryption=false.";
|
||||||
|
return RET_INPUT_PARAM_INVALID;
|
||||||
|
}
|
||||||
|
keyLen = lite::Hex2ByteArray(encKeyStr, encKey, kEncMaxLen);
|
||||||
|
if (keyLen != kEncMaxLen) {
|
||||||
|
MS_LOG(ERROR) << "enc_key " << encKeyStr << " must expressed in hexadecimal characters "
|
||||||
|
<< " and only support AES-GCM method and the key length is 16.";
|
||||||
|
return RET_INPUT_PARAM_INVALID;
|
||||||
|
}
|
||||||
|
encKeyStr.clear();
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int Flags::PreInit(int argc, const char **argv) {
|
||||||
if (argc == 1) {
|
if (argc == 1) {
|
||||||
std::cout << this->Usage() << std::endl;
|
std::cout << this->Usage() << std::endl;
|
||||||
return lite::RET_SUCCESS_EXIT;
|
return lite::RET_SUCCESS_EXIT;
|
||||||
|
@ -353,19 +415,23 @@ int Flags::Init(int argc, const char **argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!this->configFile.empty()) {
|
if (!this->configFile.empty()) {
|
||||||
ret = InitConfigFile();
|
auto ret = InitConfigFile();
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
std::cerr << "Init config file failed." << std::endl;
|
std::cerr << "Init config file failed." << std::endl;
|
||||||
return RET_INPUT_PARAM_INVALID;
|
return RET_INPUT_PARAM_INVALID;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
if (saveFP16Str == "on") {
|
int Flags::Init(int argc, const char **argv) {
|
||||||
saveFP16 = true;
|
auto ret = PreInit(argc, argv);
|
||||||
} else if (saveFP16Str == "off") {
|
if (ret != RET_OK) {
|
||||||
saveFP16 = false;
|
return ret;
|
||||||
} else {
|
}
|
||||||
std::cerr << "Init save_fp16 failed." << std::endl;
|
ret = InitSaveFP16();
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
std::cerr << "Init save fp16 failed." << std::endl;
|
||||||
return RET_INPUT_PARAM_INVALID;
|
return RET_INPUT_PARAM_INVALID;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -398,8 +464,25 @@ int Flags::Init(int argc, const char **argv) {
|
||||||
std::cerr << "Init graph input format failed." << std::endl;
|
std::cerr << "Init graph input format failed." << std::endl;
|
||||||
return RET_INPUT_PARAM_INVALID;
|
return RET_INPUT_PARAM_INVALID;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ret = InitEncrypt();
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
std::cerr << "Init encrypt failed." << std::endl;
|
||||||
|
return RET_INPUT_PARAM_INVALID;
|
||||||
|
}
|
||||||
|
|
||||||
|
ret = InitPreInference();
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
std::cerr << "Init pre inference failed." << std::endl;
|
||||||
|
return RET_INPUT_PARAM_INVALID;
|
||||||
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
Flags::~Flags() {
|
||||||
|
dec_key.clear();
|
||||||
|
encKeyStr.clear();
|
||||||
|
memset(encKey, 0, kEncMaxLen);
|
||||||
|
}
|
||||||
|
|
||||||
bool CheckOfflineParallelConfig(const std::string &file, ParallelSplitConfig *parallel_split_config) {
|
bool CheckOfflineParallelConfig(const std::string &file, ParallelSplitConfig *parallel_split_config) {
|
||||||
// device: [device0 device1] ---> {cpu, gpu}
|
// device: [device0 device1] ---> {cpu, gpu}
|
||||||
|
|
|
@ -40,6 +40,7 @@ constexpr auto kMaxSplitRatio = 10;
|
||||||
constexpr auto kComputeRate = "computeRate";
|
constexpr auto kComputeRate = "computeRate";
|
||||||
constexpr auto kSplitDevice0 = "device0";
|
constexpr auto kSplitDevice0 = "device0";
|
||||||
constexpr auto kSplitDevice1 = "device1";
|
constexpr auto kSplitDevice1 = "device1";
|
||||||
|
constexpr size_t kEncMaxLen = 16;
|
||||||
struct ParallelSplitConfig {
|
struct ParallelSplitConfig {
|
||||||
ParallelSplitType parallel_split_type_ = SplitNo;
|
ParallelSplitType parallel_split_type_ = SplitNo;
|
||||||
std::vector<int64_t> parallel_compute_rates_;
|
std::vector<int64_t> parallel_compute_rates_;
|
||||||
|
@ -50,7 +51,7 @@ class Flags : public virtual mindspore::lite::FlagParser {
|
||||||
public:
|
public:
|
||||||
Flags();
|
Flags();
|
||||||
|
|
||||||
~Flags() override = default;
|
~Flags() override;
|
||||||
|
|
||||||
int InitInputOutputDataType();
|
int InitInputOutputDataType();
|
||||||
|
|
||||||
|
@ -66,8 +67,16 @@ class Flags : public virtual mindspore::lite::FlagParser {
|
||||||
|
|
||||||
int InitExtendedIntegrationInfo(const lite::ConfigFileParser &config_file_parser);
|
int InitExtendedIntegrationInfo(const lite::ConfigFileParser &config_file_parser);
|
||||||
|
|
||||||
|
int InitEncrypt();
|
||||||
|
|
||||||
|
int InitPreInference();
|
||||||
|
|
||||||
|
int InitSaveFP16();
|
||||||
|
|
||||||
int Init(int argc, const char **argv);
|
int Init(int argc, const char **argv);
|
||||||
|
|
||||||
|
int PreInit(int argc, const char **argv);
|
||||||
|
|
||||||
std::string modelFile;
|
std::string modelFile;
|
||||||
std::string outputFile;
|
std::string outputFile;
|
||||||
std::string fmkIn;
|
std::string fmkIn;
|
||||||
|
@ -91,7 +100,19 @@ class Flags : public virtual mindspore::lite::FlagParser {
|
||||||
std::string graphInputFormatStr;
|
std::string graphInputFormatStr;
|
||||||
std::string device;
|
std::string device;
|
||||||
mindspore::Format graphInputFormat = mindspore::NHWC;
|
mindspore::Format graphInputFormat = mindspore::NHWC;
|
||||||
bool enable_micro = false;
|
std::string encKeyStr;
|
||||||
|
std::string encMode = "AES-GCM";
|
||||||
|
std::string inferStr;
|
||||||
|
#ifdef ENABLE_OPENSSL
|
||||||
|
std::string encryptionStr = "true";
|
||||||
|
bool encryption = true;
|
||||||
|
#else
|
||||||
|
std::string encryptionStr = "false";
|
||||||
|
bool encryption = false;
|
||||||
|
#endif
|
||||||
|
bool infer = false;
|
||||||
|
unsigned char encKey[kEncMaxLen];
|
||||||
|
size_t keyLen = 0;
|
||||||
|
|
||||||
lite::quant::CommonQuantParam commonQuantParam;
|
lite::quant::CommonQuantParam commonQuantParam;
|
||||||
lite::quant::MixedBitWeightQuantParam mixedBitWeightQuantParam;
|
lite::quant::MixedBitWeightQuantParam mixedBitWeightQuantParam;
|
||||||
|
|
Binary file not shown.
Loading…
Reference in New Issue