forked from mindspore-Ecosystem/mindspore
add encryption to lite
parent c337e64241
commit f670a635f0
|
@ -12,17 +12,66 @@ else()
|
|||
set(OPENSSL_PATCH_ROOT ${CMAKE_SOURCE_DIR}/third_party/patch/openssl)
|
||||
endif()
|
||||
|
||||
if(${CMAKE_SYSTEM_NAME} MATCHES "Linux" OR APPLE)
|
||||
mindspore_add_pkg(openssl
|
||||
VER 1.1.1k
|
||||
LIBS ssl crypto
|
||||
URL ${REQ_URL}
|
||||
MD5 ${MD5}
|
||||
CONFIGURE_COMMAND ./config no-zlib no-shared
|
||||
PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3711.patch
|
||||
PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3712.patch
|
||||
)
|
||||
include_directories(${openssl_INC})
|
||||
add_library(mindspore::ssl ALIAS openssl::ssl)
|
||||
add_library(mindspore::crypto ALIAS openssl::crypto)
|
||||
if(BUILD_LITE)
|
||||
if(PLATFORM_ARM64 AND ANDROID_NDK_TOOLCHAIN_INCLUDED)
|
||||
set(ANDROID_NDK_ROOT $ENV{ANDROID_NDK})
|
||||
set(PATH
|
||||
${ANDROID_NDK_ROOT}/toolchains/llvm/prebuilt/linux-x86_64/bin:
|
||||
${ANDROID_NDK_ROOT}/toolchains/arm-linux-androideabi-4.9/prebuilt/linux-x86_64/bin:
|
||||
$ENV{PATH})
|
||||
mindspore_add_pkg(openssl
|
||||
VER 1.1.1k
|
||||
LIBS ssl crypto
|
||||
URL ${REQ_URL}
|
||||
MD5 ${MD5}
|
||||
CONFIGURE_COMMAND ./Configure android-arm64 -D__ANDROID_API__=29 no-zlib
|
||||
PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3711.patch
|
||||
PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3712.patch
|
||||
)
|
||||
elseif(PLATFORM_ARM32 AND ANDROID_NDK_TOOLCHAIN_INCLUDED)
|
||||
set(ANDROID_NDK_ROOT $ENV{ANDROID_NDK})
|
||||
set(PATH
|
||||
${ANDROID_NDK_ROOT}/toolchains/llvm/prebuilt/linux-x86_64/bin:
|
||||
${ANDROID_NDK_ROOT}/toolchains/arm-linux-androideabi-4.9/prebuilt/linux-x86_64/bin:
|
||||
$ENV{PATH})
|
||||
mindspore_add_pkg(openssl
|
||||
VER 1.1.1k
|
||||
LIBS ssl crypto
|
||||
URL ${REQ_URL}
|
||||
MD5 ${MD5}
|
||||
CONFIGURE_COMMAND ./Configure android-arm -D__ANDROID_API__=29 no-zlib
|
||||
PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3711.patch
|
||||
PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3712.patch
|
||||
)
|
||||
elseif(${CMAKE_SYSTEM_NAME} MATCHES "Linux" OR APPLE)
|
||||
mindspore_add_pkg(openssl
|
||||
VER 1.1.1k
|
||||
LIBS ssl crypto
|
||||
URL ${REQ_URL}
|
||||
MD5 ${MD5}
|
||||
CONFIGURE_COMMAND ./config no-zlib no-shared
|
||||
PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3711.patch
|
||||
PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3712.patch
|
||||
)
|
||||
else()
|
||||
MESSAGE(FATAL_ERROR "openssl does not support compilation for the current environment.")
|
||||
endif()
|
||||
include_directories(${openssl_INC})
|
||||
add_library(mindspore::ssl ALIAS openssl::ssl)
|
||||
add_library(mindspore::crypto ALIAS openssl::crypto)
|
||||
else()
|
||||
if(${CMAKE_SYSTEM_NAME} MATCHES "Linux" OR APPLE)
|
||||
mindspore_add_pkg(openssl
|
||||
VER 1.1.1k
|
||||
LIBS ssl crypto
|
||||
URL ${REQ_URL}
|
||||
MD5 ${MD5}
|
||||
CONFIGURE_COMMAND ./config no-zlib no-shared
|
||||
PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3711.patch
|
||||
PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3712.patch
|
||||
)
|
||||
include_directories(${openssl_INC})
|
||||
add_library(mindspore::ssl ALIAS openssl::ssl)
|
||||
add_library(mindspore::crypto ALIAS openssl::crypto)
|
||||
endif()
|
||||
endif()
|
||||
|
|
|
@ -173,7 +173,7 @@ class MS_API Model {
|
|||
/// \return Status of operation
|
||||
Status UpdateFeatureMaps(const std::vector<MSTensor> &new_weights);
|
||||
|
||||
/// \brief Obtains optimizer params tensors of the model.
|
||||
/// \brief Obtains optimizer params tensors of the model.
|
||||
///
|
||||
/// \return The vector that includes all params tensors.
|
||||
std::vector<MSTensor> GetOptimizerParams() const;
|
||||
|
@ -256,17 +256,14 @@ class MS_API Model {
|
|||
/// \brief Build a model from model buffer so that it can run on a device. Only valid for Lite.
|
||||
///
|
||||
/// \param[in] model_data Define the buffer read from a model file.
|
||||
/// \param[in] size Define bytes number of model buffer.
|
||||
/// \param[in] data_size Define bytes number of model buffer.
|
||||
/// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only
|
||||
/// ModelType::kMindIR is valid for Lite.
|
||||
/// \param[in] model_context Define the context used to store options during execution.
|
||||
/// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16, 24, or 32.
|
||||
/// \param[in] dec_mode Define the decryption mode. Options: AES-GCM, AES-CBC.
|
||||
///
|
||||
/// \return Status.
|
||||
inline Status Build(const void *model_data, size_t data_size, ModelType model_type,
|
||||
const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {},
|
||||
const std::string &dec_mode = kDecModeAesGcm);
|
||||
Status Build(const void *model_data, size_t data_size, ModelType model_type,
|
||||
const std::shared_ptr<Context> &model_context = nullptr);
|
||||
|
||||
/// \brief Load and build a model from model buffer so that it can run on a device. Only valid for Lite.
|
||||
///
|
||||
|
@ -274,13 +271,40 @@ class MS_API Model {
|
|||
/// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only
|
||||
/// ModelType::kMindIR is valid for Lite.
|
||||
/// \param[in] model_context Define the context used to store options during execution.
|
||||
/// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16, 24, or 32.
|
||||
/// \param[in] dec_mode Define the decryption mode. Options: AES-GCM, AES-CBC.
|
||||
///
|
||||
/// \return Status.
|
||||
inline Status Build(const std::string &model_path, ModelType model_type,
|
||||
const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {},
|
||||
const std::string &dec_mode = kDecModeAesGcm);
|
||||
Status Build(const std::string &model_path, ModelType model_type,
|
||||
const std::shared_ptr<Context> &model_context = nullptr);
|
||||
|
||||
/// \brief Build a model from model buffer so that it can run on a device. Only valid for Lite.
|
||||
///
|
||||
/// \param[in] model_data Define the buffer read from a model file.
|
||||
/// \param[in] data_size Define bytes number of model buffer.
|
||||
/// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only
|
||||
/// ModelType::kMindIR is valid for Lite.
|
||||
/// \param[in] model_context Define the context used to store options during execution.
|
||||
/// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16.
|
||||
/// \param[in] dec_mode Define the decryption mode. Options: AES-GCM.
|
||||
/// \param[in] cropto_lib_path Define the openssl library path.
|
||||
///
|
||||
/// \return Status.
|
||||
Status Build(const void *model_data, size_t data_size, ModelType model_type,
|
||||
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode,
|
||||
const std::string &cropto_lib_path);
|
||||
|
||||
/// \brief Load and build a model from model buffer so that it can run on a device. Only valid for Lite.
|
||||
///
|
||||
/// \param[in] model_path Define the model path.
|
||||
/// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only
|
||||
/// ModelType::kMindIR is valid for Lite.
|
||||
/// \param[in] model_context Define the context used to store options during execution.
|
||||
/// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16.
|
||||
/// \param[in] dec_mode Define the decryption mode. Options: AES-GCM.
|
||||
/// \param[in] cropto_lib_path Define the openssl library path.
|
||||
///
|
||||
/// \return Status.
|
||||
Status Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
|
||||
const Key &dec_key, const std::string &dec_mode, const std::string &cropto_lib_path);
|
||||
|
||||
private:
|
||||
friend class Serialization;
|
||||
|
@ -291,11 +315,10 @@ class MS_API Model {
|
|||
std::vector<MSTensor> GetOutputsByNodeName(const std::vector<char> &node_name);
|
||||
Status LoadConfig(const std::vector<char> &config_path);
|
||||
Status UpdateConfig(const std::vector<char> §ion, const std::pair<std::vector<char>, std::vector<char>> &config);
|
||||
Status Build(const void *model_data, size_t data_size, ModelType model_type,
|
||||
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::vector<char> &dec_mode);
|
||||
Status Build(const std::vector<char> &model_path, ModelType model_type,
|
||||
const std::shared_ptr<Context> &model_context);
|
||||
Status Build(const std::vector<char> &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
|
||||
const Key &dec_key, const std::vector<char> &dec_mode);
|
||||
|
||||
const Key &dec_key, const std::string &dec_mode, const std::vector<char> &cropto_lib_path);
|
||||
std::shared_ptr<ModelImpl> impl_;
|
||||
};
|
||||
|
||||
|
@ -321,14 +344,15 @@ Status Model::UpdateConfig(const std::string §ion, const std::pair<std::stri
|
|||
return UpdateConfig(StringToChar(section), config_pair);
|
||||
}
|
||||
|
||||
Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
|
||||
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode) {
|
||||
return Build(model_data, data_size, model_type, model_context, dec_key, StringToChar(dec_mode));
|
||||
inline Status Model::Build(const std::string &model_path, ModelType model_type,
|
||||
const std::shared_ptr<Context> &model_context, const Key &dec_key,
|
||||
const std::string &dec_mode, const std::string &cropto_lib_path) {
|
||||
return Build(StringToChar(model_path), model_type, model_context, dec_key, dec_mode, StringToChar(cropto_lib_path));
|
||||
}
|
||||
|
||||
Status Model::Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
|
||||
const Key &dec_key, const std::string &dec_mode) {
|
||||
return Build(StringToChar(model_path), model_type, model_context, dec_key, StringToChar(dec_mode));
|
||||
inline Status Model::Build(const std::string &model_path, ModelType model_type,
|
||||
const std::shared_ptr<Context> &model_context) {
|
||||
return Build(StringToChar(model_path), model_type, model_context);
|
||||
}
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_MODEL_H
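// Editor's illustrative sketch, not part of this commit: minimal usage of the new decrypting
// Build(model_path, ...) overload declared above. The model path, key bytes and the libcrypto
// location are placeholder assumptions; per the docs above, only a 16-byte key and "AES-GCM"
// are valid for Lite.
#include <memory>
#include "include/api/context.h"
#include "include/api/model.h"
#include "include/api/types.h"

int RunEncryptedModel() {
  auto context = std::make_shared<mindspore::Context>();
  context->MutableDeviceInfo().push_back(std::make_shared<mindspore::CPUDeviceInfo>());
  char key_bytes[16] = {0};  // placeholder: must match the key used when the model was encrypted
  mindspore::Key dec_key{key_bytes, sizeof(key_bytes)};  // same construction as in the JNI code below
  mindspore::Model model;
  auto status = model.Build("model.enc.ms", mindspore::kMindIR, context, dec_key, "AES-GCM",
                            "/usr/lib/x86_64-linux-gnu/libcrypto.so.1.1");  // assumed lib path
  return status != mindspore::kSuccess ? -1 : 0;
}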
|
||||
|
|
|
@ -52,14 +52,13 @@ Status Model::Build(GraphCell graph_cell, const std::shared_ptr<Context> &model_
|
|||
return impl_->Build();
|
||||
}
|
||||
|
||||
Status Model::Build(const void *, size_t, ModelType, const std::shared_ptr<Context> &, const Key &,
|
||||
const std::vector<char> &) {
|
||||
Status Model::Build(const std::vector<char> &, ModelType, const std::shared_ptr<Context> &, const Key &,
|
||||
const std::string &, const std::vector<char> &) {
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
return kMCFailed;
|
||||
}
|
||||
|
||||
Status Model::Build(const std::vector<char> &, ModelType, const std::shared_ptr<Context> &, const Key &,
|
||||
const std::vector<char> &) {
|
||||
Status Model::Build(const std::vector<char> &, ModelType, const std::shared_ptr<Context> &) {
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
return kMCFailed;
|
||||
}
|
||||
|
|
|
@ -120,7 +120,7 @@ bool ParseMode(const std::string &mode, std::string *alg_mode, std::string *work
|
|||
}
|
||||
|
||||
EVP_CIPHER_CTX *GetEvpCipherCtx(const std::string &work_mode, const Byte *key, int32_t key_len, const Byte *iv,
|
||||
bool is_encrypt) {
|
||||
int iv_len, bool is_encrypt) {
|
||||
constexpr int32_t key_length_16 = 16;
|
||||
constexpr int32_t key_length_24 = 24;
|
||||
constexpr int32_t key_length_32 = 32;
|
||||
|
@ -163,8 +163,35 @@ EVP_CIPHER_CTX *GetEvpCipherCtx(const std::string &work_mode, const Byte *key, i
|
|||
int32_t ret = 0;
|
||||
auto ctx = EVP_CIPHER_CTX_new();
|
||||
if (is_encrypt) {
|
||||
ret = EVP_EncryptInit_ex(ctx, funcPtr(), NULL, NULL, NULL);
|
||||
if (ret != 1) {
|
||||
MS_LOG(ERROR) << "EVP_EncryptInit_ex failed";
|
||||
EVP_CIPHER_CTX_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, iv_len, NULL) != 1) {
|
||||
MS_LOG(ERROR) << "EVP_EncryptInit_ex failed";
|
||||
EVP_CIPHER_CTX_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
ret = EVP_EncryptInit_ex(ctx, funcPtr(), NULL, key, iv);
|
||||
if (ret != 1) {
|
||||
MS_LOG(ERROR) << "EVP_EncryptInit_ex failed";
|
||||
EVP_CIPHER_CTX_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
} else {
|
||||
ret = EVP_DecryptInit_ex(ctx, funcPtr(), NULL, NULL, NULL);
|
||||
if (ret != 1) {
|
||||
MS_LOG(ERROR) << "EVP_DecryptInit_ex failed";
|
||||
EVP_CIPHER_CTX_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, iv_len, NULL) != 1) {
|
||||
MS_LOG(ERROR) << "EVP_DecryptInit_ex failed";
|
||||
EVP_CIPHER_CTX_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
ret = EVP_DecryptInit_ex(ctx, funcPtr(), NULL, key, iv);
|
||||
}
|
||||
|
||||
|
@ -183,7 +210,7 @@ EVP_CIPHER_CTX *GetEvpCipherCtx(const std::string &work_mode, const Byte *key, i
|
|||
}
|
||||
|
||||
bool BlockEncrypt(Byte *encrypt_data, size_t *encrypt_data_len, const std::vector<Byte> &plain_data, const Byte *key,
|
||||
int32_t key_len, const std::string &enc_mode) {
|
||||
int32_t key_len, const std::string &enc_mode, unsigned char *tag) {
|
||||
size_t encrypt_data_buf_len = *encrypt_data_len;
|
||||
int32_t cipher_len = 0;
|
||||
int32_t iv_len = AES_BLOCK_SIZE;
|
||||
|
@ -201,7 +228,7 @@ bool BlockEncrypt(Byte *encrypt_data, size_t *encrypt_data_len, const std::vecto
|
|||
return false;
|
||||
}
|
||||
|
||||
auto ctx = GetEvpCipherCtx(work_mode, key, key_len, iv.data(), true);
|
||||
auto ctx = GetEvpCipherCtx(work_mode, key, key_len, iv.data(), static_cast<int32_t>(iv.size()), true);
|
||||
if (ctx == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX.";
|
||||
return false;
|
||||
|
@ -214,15 +241,19 @@ bool BlockEncrypt(Byte *encrypt_data, size_t *encrypt_data_len, const std::vecto
|
|||
MS_LOG(ERROR) << "EVP_EncryptUpdate failed";
|
||||
return false;
|
||||
}
|
||||
if (work_mode == "CBC") {
|
||||
int32_t flen = 0;
|
||||
ret_evp = EVP_EncryptFinal_ex(ctx, cipher_data_buf.data() + cipher_len, &flen);
|
||||
if (ret_evp != 1) {
|
||||
MS_LOG(ERROR) << "EVP_EncryptFinal_ex failed";
|
||||
return false;
|
||||
}
|
||||
cipher_len += flen;
|
||||
int32_t flen = 0;
|
||||
ret_evp = EVP_EncryptFinal_ex(ctx, cipher_data_buf.data() + cipher_len, &flen);
|
||||
if (ret_evp != 1) {
|
||||
MS_LOG(ERROR) << "EVP_EncryptFinal_ex failed";
|
||||
return false;
|
||||
}
|
||||
cipher_len += flen;
|
||||
|
||||
if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_GET_TAG, Byte16, tag) != 1) {
|
||||
MS_LOG(ERROR) << "EVP_CIPHER_CTX_ctrl failed";
|
||||
return false;
|
||||
}
|
||||
|
||||
EVP_CIPHER_CTX_free(ctx);
|
||||
|
||||
size_t offset = 0;
|
||||
|
@ -266,7 +297,7 @@ bool BlockEncrypt(Byte *encrypt_data, size_t *encrypt_data_len, const std::vecto
|
|||
}
|
||||
|
||||
bool BlockDecrypt(Byte *plain_data, int32_t *plain_len, const Byte *encrypt_data, size_t encrypt_len, const Byte *key,
|
||||
int32_t key_len, const std::string &dec_mode) {
|
||||
int32_t key_len, const std::string &dec_mode, unsigned char *tag) {
|
||||
std::string alg_mode;
|
||||
std::string work_mode;
|
||||
if (!ParseMode(dec_mode, &alg_mode, &work_mode)) {
|
||||
|
@ -277,7 +308,7 @@ bool BlockDecrypt(Byte *plain_data, int32_t *plain_len, const Byte *encrypt_data
|
|||
if (!ParseEncryptData(encrypt_data, encrypt_len, &iv, &cipher_data)) {
|
||||
return false;
|
||||
}
|
||||
auto ctx = GetEvpCipherCtx(work_mode, key, key_len, iv.data(), false);
|
||||
auto ctx = GetEvpCipherCtx(work_mode, key, key_len, iv.data(), iv.size(), false);
|
||||
if (ctx == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX.";
|
||||
return false;
|
||||
|
@ -288,15 +319,20 @@ bool BlockDecrypt(Byte *plain_data, int32_t *plain_len, const Byte *encrypt_data
|
|||
MS_LOG(ERROR) << "EVP_DecryptUpdate failed";
|
||||
return false;
|
||||
}
|
||||
if (work_mode == "CBC") {
|
||||
int32_t mlen = 0;
|
||||
ret = EVP_DecryptFinal_ex(ctx, plain_data + *plain_len, &mlen);
|
||||
if (ret != 1) {
|
||||
MS_LOG(ERROR) << "EVP_DecryptFinal_ex failed";
|
||||
return false;
|
||||
}
|
||||
*plain_len += mlen;
|
||||
|
||||
if (!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_TAG, Byte16, tag)) {
|
||||
MS_LOG(ERROR) << "EVP_CIPHER_CTX_ctrl failed";
|
||||
return false;
|
||||
}
|
||||
|
||||
int32_t mlen = 0;
|
||||
ret = EVP_DecryptFinal_ex(ctx, plain_data + *plain_len, &mlen);
|
||||
if (ret != 1) {
|
||||
MS_LOG(ERROR) << "EVP_DecryptFinal_ex failed";
|
||||
return false;
|
||||
}
|
||||
*plain_len += mlen;
|
||||
|
||||
EVP_CIPHER_CTX_free(ctx);
|
||||
return true;
|
||||
}
|
||||
|
@ -319,7 +355,9 @@ std::unique_ptr<Byte[]> Encrypt(size_t *encrypt_len, const Byte *plain_data, siz
|
|||
size_t block_enc_len = block_enc_buf.size();
|
||||
size_t cur_block_size = std::min(MAX_BLOCK_SIZE, plain_len - offset);
|
||||
block_buf.assign(plain_data + offset, plain_data + offset + cur_block_size);
|
||||
if (!BlockEncrypt(block_enc_buf.data(), &block_enc_len, block_buf, key, static_cast<int32_t>(key_len), enc_mode)) {
|
||||
unsigned char tag[Byte16];
|
||||
if (!BlockEncrypt(block_enc_buf.data(), &block_enc_len, block_buf, key, static_cast<int32_t>(key_len), enc_mode,
|
||||
tag)) {
|
||||
MS_LOG(ERROR) << "Failed to encrypt data, please check if enc_key or enc_mode is valid.";
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -332,6 +370,13 @@ std::unique_ptr<Byte[]> Encrypt(size_t *encrypt_len, const Byte *plain_data, siz
|
|||
}
|
||||
*encrypt_len += sizeof(int32_t);
|
||||
|
||||
capacity = std::min(encrypt_buf_len - *encrypt_len, SECUREC_MEM_MAX_LEN); // avoid dest size over 2gb
|
||||
ret = memcpy_s(encrypt_data.get() + *encrypt_len, capacity, tag, Byte16);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret;
|
||||
}
|
||||
*encrypt_len += Byte16;
|
||||
|
||||
capacity = std::min(encrypt_buf_len - *encrypt_len, SECUREC_MEM_MAX_LEN);
|
||||
ret = memcpy_s(encrypt_data.get() + *encrypt_len, capacity, block_enc_buf.data(), block_enc_len);
|
||||
if (ret != 0) {
|
||||
|
@ -371,6 +416,10 @@ std::unique_ptr<Byte[]> Decrypt(size_t *decrypt_len, const std::string &encrypt_
|
|||
MS_LOG(ERROR) << "File \"" << encrypt_data_path << "\" is not an encrypted file and cannot be decrypted";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
unsigned char tag[Byte16];
|
||||
fid.read(reinterpret_cast<char *>(tag), Byte16);
|
||||
|
||||
fid.read(int_buf.data(), static_cast<int64_t>(sizeof(int32_t)));
|
||||
auto block_size = ByteToInt(reinterpret_cast<Byte *>(int_buf.data()), int_buf.size());
|
||||
if (block_size < 0) {
|
||||
|
@ -379,7 +428,7 @@ std::unique_ptr<Byte[]> Decrypt(size_t *decrypt_len, const std::string &encrypt_
|
|||
}
|
||||
fid.read(block_buf.data(), static_cast<int64_t>(block_size));
|
||||
if (!(BlockDecrypt(decrypt_block_buf.data(), &decrypt_block_len, reinterpret_cast<Byte *>(block_buf.data()),
|
||||
static_cast<size_t>(block_size), key, static_cast<int32_t>(key_len), dec_mode))) {
|
||||
static_cast<size_t>(block_size), key, static_cast<int32_t>(key_len), dec_mode, tag))) {
|
||||
MS_LOG(ERROR) << "Failed to decrypt data, please check if dec_key or dec_mode is valid";
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -409,6 +458,10 @@ std::unique_ptr<Byte[]> Decrypt(size_t *decrypt_len, const Byte *model_data, siz
|
|||
size_t offset = 0;
|
||||
*decrypt_len = 0;
|
||||
while (offset < data_size) {
|
||||
if (offset + sizeof(int32_t) > data_size) {
|
||||
MS_LOG(ERROR) << "assign len is invalid.";
|
||||
return nullptr;
|
||||
}
|
||||
int_buf.assign(model_data + offset, model_data + offset + sizeof(int32_t));
|
||||
offset += int_buf.size();
|
||||
auto cipher_flag = ByteToInt(reinterpret_cast<Byte *>(int_buf.data()), int_buf.size());
|
||||
|
@ -416,27 +469,44 @@ std::unique_ptr<Byte[]> Decrypt(size_t *decrypt_len, const Byte *model_data, siz
|
|||
MS_LOG(ERROR) << "model_data is not encrypted and therefore cannot be decrypted.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
unsigned char tag[Byte16];
|
||||
if (offset + Byte16 > data_size) {
|
||||
MS_LOG(ERROR) << "buffer is invalid.";
|
||||
return nullptr;
|
||||
}
|
||||
auto ret = memcpy_s(tag, Byte16, model_data + offset, Byte16);
|
||||
if (ret != EOK) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s failed " << ret;
|
||||
}
|
||||
offset += Byte16;
|
||||
if (offset + sizeof(int32_t) > data_size) {
|
||||
MS_LOG(ERROR) << "assign len is invalid.";
|
||||
return nullptr;
|
||||
}
|
||||
int_buf.assign(model_data + offset, model_data + offset + sizeof(int32_t));
|
||||
offset += int_buf.size();
|
||||
auto block_size = ByteToInt(reinterpret_cast<Byte *>(int_buf.data()), int_buf.size());
|
||||
if (block_size < 0) {
|
||||
if (block_size <= 0) {
|
||||
MS_LOG(ERROR) << "The block_size read from the cipher data must be positive, but got " << block_size;
|
||||
return nullptr;
|
||||
}
|
||||
if (offset + block_size > data_size) {
|
||||
MS_LOG(ERROR) << "assign len is invalid.";
|
||||
return nullptr;
|
||||
}
|
||||
block_buf.assign(model_data + offset, model_data + offset + block_size);
|
||||
offset += block_buf.size();
|
||||
if (!(BlockDecrypt(decrypt_block_buf.data(), &decrypt_block_len, reinterpret_cast<Byte *>(block_buf.data()),
|
||||
block_buf.size(), key, static_cast<int32_t>(key_len), dec_mode))) {
|
||||
block_buf.size(), key, static_cast<int32_t>(key_len), dec_mode, tag))) {
|
||||
MS_LOG(ERROR) << "Failed to decrypt data, please check if dec_key or dec_mode is valid";
|
||||
return nullptr;
|
||||
}
|
||||
size_t capacity = std::min(data_size - *decrypt_len, SECUREC_MEM_MAX_LEN);
|
||||
auto ret = memcpy_s(decrypt_data.get() + *decrypt_len, capacity, decrypt_block_buf.data(),
|
||||
static_cast<size_t>(decrypt_block_len));
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret;
|
||||
ret = memcpy_s(decrypt_data.get() + *decrypt_len, data_size, decrypt_block_buf.data(),
|
||||
static_cast<size_t>(decrypt_block_len));
|
||||
if (ret != EOK) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s failed " << ret;
|
||||
}
|
||||
|
||||
*decrypt_len += static_cast<size_t>(decrypt_block_len);
|
||||
}
|
||||
return decrypt_data;
|
||||
|
|
|
@ -26,6 +26,7 @@ namespace mindspore {
|
|||
constexpr size_t MAX_BLOCK_SIZE = 512 * 1024 * 1024;  // Maximum ciphertext segment, unit is Byte
|
||||
constexpr size_t RESERVED_BYTE_PER_BLOCK = 50;  // Reserved bytes per block to save additional info
|
||||
constexpr unsigned int MAGIC_NUM = 0x7F3A5ED8; // Magic number
|
||||
constexpr size_t Byte16 = 16;
|
||||
|
||||
MS_CORE_API std::unique_ptr<Byte[]> Encrypt(size_t *encrypt_len, const Byte *plain_data, size_t plain_len,
|
||||
const Byte *key, size_t key_len, const std::string &enc_mode);
|
||||
|
|
|
@ -34,7 +34,7 @@ option(MSLITE_ENABLE_V0 "support v0 schema" on)
|
|||
option(MSLITE_ENABLE_FP16 "Whether to compile Fp16 operator" off)
|
||||
option(MSLITE_ENABLE_INT8 "Whether to compile Int8 operator" on)
|
||||
option(MSLITE_ENABLE_ACL "enable ACL" off)
|
||||
option(MSLITE_ENABLE_MODEL_ENCRYPTION "enable model encryption, only converter support" on)
|
||||
option(MSLITE_ENABLE_MODEL_ENCRYPTION "enable model encryption, only converter support" off)
|
||||
option(MSLITE_ENABLE_SPARSE_COMPUTE "enable sparse kernel" off)
|
||||
option(MSLITE_ENABLE_RUNTIME_CONVERT "enable runtime convert" off)
|
||||
option(MSLITE_ENABLE_RUNTIME_GLOG "enable runtime glog" off)
|
||||
|
@ -127,7 +127,11 @@ if(DEFINED ENV{MSLITE_MINDDATA_IMPLEMENT})
|
|||
set(MSLITE_MINDDATA_IMPLEMENT $ENV{MSLITE_MINDDATA_IMPLEMENT})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_MODEL_ENCRYPTION})
|
||||
set(MSLITE_ENABLE_MODEL_ENCRYPTION $ENV{MSLITE_ENABLE_MODEL_ENCRYPTION})
|
||||
if((${CMAKE_SYSTEM_NAME} MATCHES "Linux" AND PLATFORM_X86_64) OR (PLATFORM_ARM AND ANDROID_NDK_TOOLCHAIN_INCLUDED))
|
||||
set(MSLITE_ENABLE_MODEL_ENCRYPTION $ENV{MSLITE_ENABLE_MODEL_ENCRYPTION})
|
||||
else()
|
||||
set(MSLITE_ENABLE_MODEL_ENCRYPTION OFF)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(DEFINED ENV{MSLITE_ENABLE_RUNTIME_CONVERT})
|
||||
|
@ -227,7 +231,7 @@ if(PLATFORM_ARM64 OR PLATFORM_ARM32)
|
|||
endif()
|
||||
set(MSLITE_ENABLE_RUNTIME_GLOG off)
|
||||
set(MSLITE_ENABLE_RUNTIME_CONVERT off)
|
||||
#set for cross - compiling toolchain
|
||||
#set for cross - compiling toolchain
|
||||
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH)
|
||||
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE BOTH)
|
||||
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE BOTH)
|
||||
|
@ -540,13 +544,15 @@ if(MSLITE_ENABLE_CONVERTER)
|
|||
include_directories(${PYTHON_INCLUDE_DIRS})
|
||||
include(${TOP_DIR}/cmake/external_libs/eigen.cmake)
|
||||
include(${TOP_DIR}/cmake/external_libs/protobuf.cmake)
|
||||
if(MSLITE_ENABLE_MODEL_ENCRYPTION)
|
||||
find_package(Patch)
|
||||
include(${TOP_DIR}/cmake/external_libs/openssl.cmake)
|
||||
endif()
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter)
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_MODEL_ENCRYPTION)
|
||||
find_package(Patch)
|
||||
include(${TOP_DIR}/cmake/external_libs/openssl.cmake)
|
||||
add_compile_definitions(ENABLE_OPENSSL)
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_MINDRT)
|
||||
add_compile_definitions(ENABLE_MINDRT)
|
||||
endif()
|
||||
|
@ -590,7 +596,7 @@ if(NOT PLATFORM_ARM)
|
|||
endif()
|
||||
|
||||
if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "lite" OR MSLITE_MINDDATA_IMPLEMENT STREQUAL "full"
|
||||
OR MSLITE_MINDDATA_IMPLEMENT STREQUAL "wrapper")
|
||||
OR MSLITE_MINDDATA_IMPLEMENT STREQUAL "wrapper")
|
||||
add_compile_definitions(ENABLE_ANDROID)
|
||||
if(NOT PLATFORM_ARM32 AND NOT PLATFORM_ARM64)
|
||||
add_compile_definitions(ENABLE_MD_LITE_X86_64)
|
||||
|
@ -605,7 +611,7 @@ endif()
|
|||
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/src/ops)
|
||||
if(ANDROID_NDK_TOOLCHAIN_INCLUDED)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter/micro/coder)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter/micro/coder)
|
||||
endif()
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/src)
|
||||
|
|
|
@ -206,6 +206,7 @@ build_lite() {
|
|||
LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_ENABLE_FP16=off -DMSLITE_ENABLE_TRAIN=off -DMSLITE_GPU_BACKEND=off"
|
||||
else
|
||||
checkndk
|
||||
export PATH=${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/bin:${ANDROID_NDK}/toolchains/arm-linux-androideabi-4.9/prebuilt/linux-x86_64/bin:${PATH}
|
||||
CMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake
|
||||
LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_MINDDATA_IMPLEMENT=lite_cv"
|
||||
LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_ENABLE_FP16=on"
|
||||
|
@ -237,6 +238,7 @@ build_lite() {
|
|||
ARM64_COMPILE_CONVERTER=ON
|
||||
else
|
||||
checkndk
|
||||
export PATH=${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/bin:${ANDROID_NDK}/toolchains/arm-linux-androideabi-4.9/prebuilt/linux-x86_64/bin:${PATH}
|
||||
CMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake
|
||||
LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DANDROID_NATIVE_API_LEVEL=19 -DANDROID_NDK=${ANDROID_NDK} -DANDROID_ABI=arm64-v8a -DANDROID_TOOLCHAIN_NAME=aarch64-linux-android-clang -DANDROID_STL=${MSLITE_ANDROID_STL}"
|
||||
LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_MINDDATA_IMPLEMENT=lite_cv"
|
||||
|
|
|
@ -58,19 +58,19 @@ public class Model {
|
|||
/**
|
||||
* Build model.
|
||||
*
|
||||
* @param buffer model buffer.
|
||||
* @param modelType model type.
|
||||
* @param context model build context.
|
||||
* @param dec_key define the key used to decrypt the ciphertext model. The key length is 16, 24, or 32.
|
||||
* @param dec_mode define the decryption mode. Options: AES-GCM, AES-CBC.
|
||||
* @param buffer model buffer.
|
||||
* @param modelType model type.
|
||||
* @param context model build context.
|
||||
* @param dec_key define the key used to decrypt the ciphertext model. The key length is 16.
|
||||
* @param dec_mode define the decryption mode. Options: AES-GCM.
|
||||
* @param cropto_lib_path define the openssl library path.
|
||||
* @return model build status.
|
||||
*/
|
||||
public boolean build(final MappedByteBuffer buffer, int modelType, MSContext context, char[] dec_key,
|
||||
String dec_mode) {
|
||||
public boolean build(final MappedByteBuffer buffer, int modelType, MSContext context, char[] dec_key, String dec_mode, String cropto_lib_path) {
|
||||
if (context == null || buffer == null || dec_key == null || dec_mode == null) {
|
||||
return false;
|
||||
}
|
||||
modelPtr = this.buildByBuffer(buffer, modelType, context.getMSContextPtr(), dec_key, dec_mode);
|
||||
modelPtr = this.buildByBuffer(buffer, modelType, context.getMSContextPtr(), dec_key, dec_mode, cropto_lib_path);
|
||||
return modelPtr != 0;
|
||||
}
|
||||
|
||||
|
@ -86,7 +86,7 @@ public class Model {
|
|||
if (context == null || buffer == null) {
|
||||
return false;
|
||||
}
|
||||
modelPtr = this.buildByBuffer(buffer, modelType, context.getMSContextPtr(), null, "");
|
||||
modelPtr = this.buildByBuffer(buffer, modelType, context.getMSContextPtr(), null, "", "");
|
||||
return modelPtr != 0;
|
||||
}
|
||||
|
||||
|
@ -94,18 +94,19 @@ public class Model {
|
|||
/**
|
||||
* Build model.
|
||||
*
|
||||
* @param modelPath model path.
|
||||
* @param modelType model type.
|
||||
* @param context model build context.
|
||||
* @param dec_key define the key used to decrypt the ciphertext model. The key length is 16, 24, or 32.
|
||||
* @param dec_mode define the decryption mode. Options: AES-GCM, AES-CBC.
|
||||
* @param modelPath model path.
|
||||
* @param modelType model type.
|
||||
* @param context model build context.
|
||||
* @param dec_key define the key used to decrypt the ciphertext model. The key length is 16.
|
||||
* @param dec_mode define the decryption mode. Options: AES-GCM.
|
||||
* @param cropto_lib_path define the openssl library path.
|
||||
* @return model build status.
|
||||
*/
|
||||
public boolean build(String modelPath, int modelType, MSContext context, char[] dec_key, String dec_mode) {
|
||||
public boolean build(String modelPath, int modelType, MSContext context, char[] dec_key, String dec_mode, String cropto_lib_path) {
|
||||
if (context == null || modelPath == null || dec_key == null || dec_mode == null) {
|
||||
return false;
|
||||
}
|
||||
modelPtr = this.buildByPath(modelPath, modelType, context.getMSContextPtr(), dec_key, dec_mode);
|
||||
modelPtr = this.buildByPath(modelPath, modelType, context.getMSContextPtr(), dec_key, dec_mode, cropto_lib_path);
|
||||
return modelPtr != 0;
|
||||
}
|
||||
|
||||
|
@ -121,7 +122,7 @@ public class Model {
|
|||
if (context == null || modelPath == null) {
|
||||
return false;
|
||||
}
|
||||
modelPtr = this.buildByPath(modelPath, modelType, context.getMSContextPtr(), null, "");
|
||||
modelPtr = this.buildByPath(modelPath, modelType, context.getMSContextPtr(), null, "", "");
|
||||
return modelPtr != 0;
|
||||
}
|
||||
|
||||
|
@ -256,8 +257,7 @@ public class Model {
|
|||
* @param outputTensorNames tensor name used for export inference graph.
|
||||
* @return Whether the export is successful.
|
||||
*/
|
||||
public boolean export(String fileName, int quantizationType, boolean isOnlyExportInfer,
|
||||
List<String> outputTensorNames) {
|
||||
public boolean export(String fileName, int quantizationType, boolean isOnlyExportInfer, List<String> outputTensorNames) {
|
||||
if (fileName == null) {
|
||||
return false;
|
||||
}
|
||||
|
@ -355,10 +355,11 @@ public class Model {
|
|||
|
||||
private native long buildByGraph(long graphPtr, long contextPtr, long cfgPtr);
|
||||
|
||||
private native long buildByPath(String modelPath, int modelType, long contextPtr, char[] dec_key, String dec_mod);
|
||||
private native long buildByPath(String modelPath, int modelType, long contextPtr,
|
||||
char[] dec_key, String dec_mod, String cropto_lib_path);
|
||||
|
||||
private native long buildByBuffer(MappedByteBuffer buffer, int modelType, long contextPtr, char[] dec_key,
|
||||
String dec_mod);
|
||||
private native long buildByBuffer(MappedByteBuffer buffer, int modelType, long contextPtr,
|
||||
char[] dec_key, String dec_mod, String cropto_lib_path);
|
||||
|
||||
private native List<Long> getInputs(long modelPtr);
|
||||
|
||||
|
@ -380,8 +381,7 @@ public class Model {
|
|||
|
||||
private native boolean resize(long modelPtr, long[] inputs, int[][] dims);
|
||||
|
||||
private native boolean export(long modelPtr, String fileName, int quantizationType, boolean isOnlyExportInfer,
|
||||
String[] outputTensorNames);
|
||||
private native boolean export(long modelPtr, String fileName, int quantizationType, boolean isOnlyExportInfer, String[] outputTensorNames);
|
||||
|
||||
private native List<Long> getFeatureMaps(long modelPtr);
|
||||
|
||||
|
@ -389,6 +389,5 @@ public class Model {
|
|||
|
||||
private native boolean setLearningRate(long modelPtr, float learning_rate);
|
||||
|
||||
private native boolean setupVirtualBatch(long modelPtr, int virtualBatchMultiplier, float learningRate,
|
||||
float momentum);
|
||||
private native boolean setupVirtualBatch(long modelPtr, int virtualBatchMultiplier, float learningRate, float momentum);
|
||||
}
|
||||
|
|
|
@ -68,7 +68,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByGraph(JNIEnv
|
|||
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByBuffer(JNIEnv *env, jobject thiz,
|
||||
jobject model_buffer, jint model_type,
|
||||
jlong context_ptr, jcharArray key_str,
|
||||
jstring dec_mod) {
|
||||
jstring dec_mod, jstring cropto_lib_path) {
|
||||
if (model_buffer == nullptr) {
|
||||
MS_LOGE("Buffer from java is nullptr");
|
||||
return reinterpret_cast<jlong>(nullptr);
|
||||
|
@ -116,7 +116,8 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByBuffer(JNIEnv
|
|||
}
|
||||
env->ReleaseCharArrayElements(key_str, key_array, JNI_ABORT);
|
||||
mindspore::Key dec_key{dec_key_data, key_len};
|
||||
status = model->Build(model_buf, buffer_len, c_model_type, context, dec_key, c_dec_mod);
|
||||
auto c_cropto_lib_path = env->GetStringUTFChars(cropto_lib_path, JNI_FALSE);
|
||||
status = model->Build(model_buf, buffer_len, c_model_type, context, dec_key, c_dec_mod, c_cropto_lib_path);
|
||||
} else {
|
||||
status = model->Build(model_buf, buffer_len, c_model_type, context);
|
||||
}
|
||||
|
@ -130,7 +131,8 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByBuffer(JNIEnv
|
|||
|
||||
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByPath(JNIEnv *env, jobject thiz, jstring model_path,
|
||||
jint model_type, jlong context_ptr,
|
||||
jcharArray key_str, jstring dec_mod) {
|
||||
jcharArray key_str, jstring dec_mod,
|
||||
jstring cropto_lib_path) {
|
||||
auto c_model_path = env->GetStringUTFChars(model_path, JNI_FALSE);
|
||||
mindspore::ModelType c_model_type;
|
||||
if (model_type >= static_cast<int>(mindspore::kMindIR) && model_type <= static_cast<int>(mindspore::kMindIR_Lite)) {
|
||||
|
@ -172,7 +174,8 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByPath(JNIEnv *
|
|||
}
|
||||
env->ReleaseCharArrayElements(key_str, key_array, JNI_ABORT);
|
||||
mindspore::Key dec_key{dec_key_data, key_len};
|
||||
status = model->Build(c_model_path, c_model_type, context, dec_key, c_dec_mod);
|
||||
auto c_cropto_lib_path = env->GetStringUTFChars(cropto_lib_path, JNI_FALSE);
|
||||
status = model->Build(c_model_path, c_model_type, context, dec_key, c_dec_mod, c_cropto_lib_path);
|
||||
} else {
|
||||
status = model->Build(c_model_path, c_model_type, context);
|
||||
}
|
||||
|
|
|
@ -131,6 +131,14 @@ set(LITE_SRC
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/cpu_info.cc
|
||||
)
|
||||
|
||||
if(MSLITE_ENABLE_MODEL_ENCRYPTION)
|
||||
set(LITE_SRC
|
||||
${LITE_SRC}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common/decrypt.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common/dynamic_library_loader.cc
|
||||
)
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_SERVER_INFERENCE)
|
||||
set(LITE_SRC
|
||||
${LITE_SRC}
|
||||
|
@ -272,8 +280,7 @@ set(TRAIN_SRC
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/train/classification_train_accuracy_monitor.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/train/train_export.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/train/opt_allocator.cc
|
||||
${TOOLS_DIR}/common/storage.cc
|
||||
${TOOLS_DIR}/common/meta_graph_serializer.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common/storage.cc
|
||||
${TOOLS_DIR}/converter/optimizer.cc
|
||||
${TOOLS_DIR}/converter/legacy_optimizer/fusion/fusion_pass.cc
|
||||
${TOOLS_DIR}/converter/legacy_optimizer/fusion/fusion_pattern.cc
|
||||
|
|
|
@ -0,0 +1,323 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/common/decrypt.h"
|
||||
#ifdef ENABLE_OPENSSL
|
||||
#include <openssl/aes.h>
|
||||
#include <openssl/evp.h>
|
||||
#include <openssl/rand.h>
|
||||
#include <regex>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <algorithm>
|
||||
#include "src/common/dynamic_library_loader.h"
|
||||
#include "src/common/file_utils.h"
|
||||
#endif
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/common/log_util.h"
|
||||
|
||||
#ifndef SECUREC_MEM_MAX_LEN
|
||||
#define SECUREC_MEM_MAX_LEN 0x7fffffffUL
|
||||
#endif
|
||||
|
||||
namespace mindspore::lite {
|
||||
#ifndef ENABLE_OPENSSL
|
||||
std::unique_ptr<Byte[]> Decrypt(const std::string &lib_path, size_t *, const Byte *, const size_t, const Byte *,
|
||||
const size_t, const std::string &) {
|
||||
MS_LOG(ERROR) << "The feature is only supported on the Linux platform "
|
||||
"when the OPENSSL compilation option is enabled.";
|
||||
return nullptr;
|
||||
}
|
||||
#else
|
||||
namespace {
|
||||
constexpr size_t MAX_BLOCK_SIZE = 512 * 1024 * 1024;  // Maximum ciphertext segment, unit is Byte
|
||||
constexpr size_t Byte16 = 16;  // 16 bytes: length of the AES-GCM authentication tag
|
||||
constexpr unsigned int MAGIC_NUM = 0x7F3A5ED8; // Magic number
|
||||
DynamicLibraryLoader loader;
|
||||
} // namespace
|
||||
int32_t ByteToInt(const Byte *byteArray, size_t length) {
|
||||
if (byteArray == nullptr) {
|
||||
MS_LOG(ERROR) << "There is a null pointer in the input parameter.";
|
||||
return -1;
|
||||
}
|
||||
if (length < sizeof(int32_t)) {
|
||||
MS_LOG(ERROR) << "Length of byteArray is " << length << ", less than sizeof(int32_t): 4.";
|
||||
return -1;
|
||||
}
|
||||
return *(reinterpret_cast<const int32_t *>(byteArray));
|
||||
}
|
||||
|
||||
bool ParseEncryptData(const Byte *encrypt_data, size_t encrypt_len, std::vector<Byte> *iv,
|
||||
std::vector<Byte> *cipher_data) {
|
||||
if (encrypt_data == nullptr || iv == nullptr || cipher_data == nullptr) {
|
||||
MS_LOG(ERROR) << "There is a null pointer in the input parameter.";
|
||||
return false;
|
||||
}
|
||||
// encrypt_data is organized in order to iv_len, iv, cipher_len, cipher_data
|
||||
std::vector<Byte> int_buf(sizeof(int32_t));
|
||||
if (sizeof(int32_t) > encrypt_len) {
|
||||
MS_LOG(ERROR) << "assign len is invalid.";
|
||||
return false;
|
||||
}
|
||||
int_buf.assign(encrypt_data, encrypt_data + sizeof(int32_t));
|
||||
auto iv_len = ByteToInt(int_buf.data(), int_buf.size());
|
||||
if (iv_len <= 0 || iv_len + sizeof(int32_t) + sizeof(int32_t) > encrypt_len) {
|
||||
MS_LOG(ERROR) << "assign len is invalid.";
|
||||
return false;
|
||||
}
|
||||
int_buf.assign(encrypt_data + iv_len + sizeof(int32_t), encrypt_data + iv_len + sizeof(int32_t) + sizeof(int32_t));
|
||||
auto cipher_len = ByteToInt(int_buf.data(), int_buf.size());
|
||||
if (((static_cast<size_t>(iv_len) + sizeof(int32_t) + static_cast<size_t>(cipher_len) + sizeof(int32_t)) !=
|
||||
encrypt_len)) {
|
||||
MS_LOG(ERROR) << "Failed to parse encrypt data.";
|
||||
return false;
|
||||
}
|
||||
|
||||
(*iv).assign(encrypt_data + sizeof(int32_t), encrypt_data + sizeof(int32_t) + iv_len);
|
||||
if (cipher_len <= 0 || sizeof(int32_t) + iv_len + sizeof(int32_t) + cipher_len > encrypt_len) {
|
||||
MS_LOG(ERROR) << "assign len is invalid.";
|
||||
return false;
|
||||
}
|
||||
(*cipher_data)
|
||||
.assign(encrypt_data + sizeof(int32_t) + iv_len + sizeof(int32_t),
|
||||
encrypt_data + sizeof(int32_t) + iv_len + sizeof(int32_t) + cipher_len);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ParseMode(const std::string &mode, std::string *alg_mode, std::string *work_mode) {
|
||||
if (alg_mode == nullptr || work_mode == nullptr) {
|
||||
MS_LOG(ERROR) << "There is a null pointer in the input parameter.";
|
||||
return false;
|
||||
}
|
||||
std::smatch results;
|
||||
std::regex re("([A-Z]{3})-([A-Z]{3})");
|
||||
if (!(std::regex_match(mode.c_str(), re) && std::regex_search(mode, results, re))) {
|
||||
MS_LOG(ERROR) << "Mode " << mode << " is invalid.";
|
||||
return false;
|
||||
}
|
||||
const size_t index_1 = 1;
|
||||
const size_t index_2 = 2;
|
||||
*alg_mode = results[index_1];
|
||||
*work_mode = results[index_2];
|
||||
return true;
|
||||
}
|
||||
|
||||
EVP_CIPHER_CTX *GetEvpCipherCtx(const std::string &work_mode, const Byte *key, int32_t key_len, const Byte *iv,
|
||||
int iv_len) {
|
||||
constexpr int32_t key_length_16 = 16;
|
||||
constexpr int32_t key_length_24 = 24;
|
||||
constexpr int32_t key_length_32 = 32;
|
||||
const EVP_CIPHER *(*funcPtr)() = nullptr;
|
||||
if (work_mode == "GCM") {
|
||||
switch (key_len) {
|
||||
case key_length_16:
|
||||
funcPtr = (const EVP_CIPHER *(*)())loader.GetFunc("EVP_aes_128_gcm");
|
||||
break;
|
||||
case key_length_24:
|
||||
funcPtr = (const EVP_CIPHER *(*)())loader.GetFunc("EVP_aes_192_gcm");
|
||||
break;
|
||||
case key_length_32:
|
||||
funcPtr = (const EVP_CIPHER *(*)())loader.GetFunc("EVP_aes_256_gcm");
|
||||
break;
|
||||
default:
|
||||
MS_LOG(ERROR) << "The key length must be 16, 24 or 32, but got " << key_len << ".";
|
||||
return nullptr;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Work mode " << work_mode << " is invalid.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int32_t ret = 0;
|
||||
EVP_CIPHER_CTX *(*EVP_CIPHER_CTX_new)() = (EVP_CIPHER_CTX * (*)()) loader.GetFunc("EVP_CIPHER_CTX_new");
|
||||
EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
|
||||
if (ctx == nullptr) {
|
||||
MS_LOG(ERROR) << "EVP_CIPHER_CTX_new failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int (*EVP_DecryptInit_ex)(EVP_CIPHER_CTX *, const EVP_CIPHER *, ENGINE *, const unsigned char *,
|
||||
const unsigned char *) =
|
||||
(int (*)(EVP_CIPHER_CTX *, const EVP_CIPHER *, ENGINE *, const unsigned char *,
|
||||
const unsigned char *))loader.GetFunc("EVP_DecryptInit_ex");
|
||||
ret = EVP_DecryptInit_ex(ctx, funcPtr(), NULL, NULL, NULL);
|
||||
int (*EVP_CIPHER_CTX_ctrl)(EVP_CIPHER_CTX *, int, int, void *) =
|
||||
(int (*)(EVP_CIPHER_CTX * ctx, int type, int arg, void *ptr)) loader.GetFunc("EVP_CIPHER_CTX_ctrl");
|
||||
if (ret != 1) {
|
||||
MS_LOG(ERROR) << "EVP_DecryptInit_ex failed";
|
||||
void (*EVP_CIPHER_CTX_free)(EVP_CIPHER_CTX *) = (void (*)(EVP_CIPHER_CTX *))loader.GetFunc("EVP_CIPHER_CTX_free");
|
||||
EVP_CIPHER_CTX_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, iv_len, NULL) != 1) {
|
||||
MS_LOG(ERROR) << "EVP_DecryptInit_ex failed";
|
||||
void (*EVP_CIPHER_CTX_free)(EVP_CIPHER_CTX *) = (void (*)(EVP_CIPHER_CTX *))loader.GetFunc("EVP_CIPHER_CTX_free");
|
||||
EVP_CIPHER_CTX_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
ret = EVP_DecryptInit_ex(ctx, funcPtr(), NULL, key, iv);
|
||||
if (ret != 1) {
|
||||
MS_LOG(ERROR) << "EVP_DecryptInit_ex failed";
|
||||
void (*EVP_CIPHER_CTX_free)(EVP_CIPHER_CTX *) = (void (*)(EVP_CIPHER_CTX *))loader.GetFunc("EVP_CIPHER_CTX_free");
|
||||
EVP_CIPHER_CTX_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
return ctx;
|
||||
}
|
||||
|
||||
bool BlockDecrypt(Byte *plain_data, int32_t *plain_len, const Byte *encrypt_data, size_t encrypt_len, const Byte *key,
|
||||
int32_t key_len, const std::string &dec_mode, unsigned char *tag) {
|
||||
if (plain_data == nullptr || plain_len == nullptr || encrypt_data == nullptr || key == nullptr) {
|
||||
MS_LOG(ERROR) << "There is a null pointer in the input parameter.";
|
||||
return false;
|
||||
}
|
||||
std::string alg_mode;
|
||||
std::string work_mode;
|
||||
if (!ParseMode(dec_mode, &alg_mode, &work_mode)) {
|
||||
return false;
|
||||
}
|
||||
std::vector<Byte> iv;
|
||||
std::vector<Byte> cipher_data;
|
||||
if (!ParseEncryptData(encrypt_data, encrypt_len, &iv, &cipher_data)) {
|
||||
return false;
|
||||
}
|
||||
auto ctx = GetEvpCipherCtx(work_mode, key, key_len, iv.data(), static_cast<int32_t>(iv.size()));
|
||||
if (ctx == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX.";
|
||||
return false;
|
||||
}
|
||||
int (*EVP_DecryptUpdate)(EVP_CIPHER_CTX *, unsigned char *, int *, const unsigned char *, int) =
|
||||
(int (*)(EVP_CIPHER_CTX *, unsigned char *, int *, const unsigned char *, int))loader.GetFunc("EVP_DecryptUpdate");
|
||||
auto ret =
|
||||
EVP_DecryptUpdate(ctx, plain_data, plain_len, cipher_data.data(), static_cast<int32_t>(cipher_data.size()));
|
||||
if (ret != 1) {
|
||||
MS_LOG(ERROR) << "EVP_DecryptUpdate failed";
|
||||
return false;
|
||||
}
|
||||
|
||||
int (*EVP_CIPHER_CTX_ctrl)(EVP_CIPHER_CTX *, int, int, void *) =
|
||||
(int (*)(EVP_CIPHER_CTX *, int, int, void *))loader.GetFunc("EVP_CIPHER_CTX_ctrl");
|
||||
if (!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_TAG, Byte16, tag)) {
|
||||
MS_LOG(ERROR) << "EVP_CIPHER_CTX_ctrl failed";
|
||||
return false;
|
||||
}
|
||||
|
||||
int32_t mlen = 0;
|
||||
int (*EVP_DecryptFinal_ex)(EVP_CIPHER_CTX *, unsigned char *, int *) =
|
||||
(int (*)(EVP_CIPHER_CTX *, unsigned char *, int *))loader.GetFunc("EVP_DecryptFinal_ex");
|
||||
ret = EVP_DecryptFinal_ex(ctx, plain_data + *plain_len, &mlen);
|
||||
if (ret != 1) {
|
||||
MS_LOG(ERROR) << "EVP_DecryptFinal_ex failed";
|
||||
return false;
|
||||
}
|
||||
*plain_len += mlen;
|
||||
|
||||
void (*EVP_CIPHER_CTX_free)(EVP_CIPHER_CTX *) = (void (*)(EVP_CIPHER_CTX *))loader.GetFunc("EVP_CIPHER_CTX_free");
|
||||
EVP_CIPHER_CTX_free(ctx);
|
||||
iv.assign(iv.size(), 0);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::unique_ptr<Byte[]> Decrypt(const std::string &lib_path, size_t *decrypt_len, const Byte *model_data,
|
||||
const size_t data_size, const Byte *key, const size_t key_len,
|
||||
const std::string &dec_mode) {
|
||||
if (model_data == nullptr) {
|
||||
MS_LOG(ERROR) << "model_data is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
if (key == nullptr) {
|
||||
MS_LOG(ERROR) << "key is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
if (decrypt_len == nullptr) {
|
||||
MS_LOG(ERROR) << "decrypt_len is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
auto ret = loader.Open(lib_path);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "loader open failed.";
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<char> block_buf;
|
||||
std::vector<char> int_buf(sizeof(int32_t));
|
||||
std::vector<Byte> decrypt_block_buf(MAX_BLOCK_SIZE);
|
||||
auto decrypt_data = std::make_unique<Byte[]>(data_size);
|
||||
int32_t decrypt_block_len;
|
||||
|
||||
size_t offset = 0;
|
||||
*decrypt_len = 0;
|
||||
if (dec_mode != "AES-GCM") {
|
||||
MS_LOG(ERROR) << "dec_mode only supports AES-GCM.";
|
||||
return nullptr;
|
||||
}
|
||||
if (key_len != Byte16) {
|
||||
MS_LOG(ERROR) << "key_len only supports 16.";
|
||||
return nullptr;
|
||||
}
|
||||
while (offset < data_size) {
|
||||
if (offset + sizeof(int32_t) > data_size) {
|
||||
MS_LOG(ERROR) << "assign len is invalid.";
|
||||
return nullptr;
|
||||
}
|
||||
int_buf.assign(model_data + offset, model_data + offset + sizeof(int32_t));
|
||||
offset += int_buf.size();
|
||||
auto cipher_flag = ByteToInt(reinterpret_cast<Byte *>(int_buf.data()), int_buf.size());
|
||||
if (static_cast<unsigned int>(cipher_flag) != MAGIC_NUM) {
|
||||
MS_LOG(ERROR) << "model_data is not encrypted and therefore cannot be decrypted.";
|
||||
return nullptr;
|
||||
}
|
||||
unsigned char tag[Byte16];
|
||||
if (offset + Byte16 > data_size) {
|
||||
MS_LOG(ERROR) << "buffer is invalid.";
|
||||
return nullptr;
|
||||
}
|
||||
memcpy(tag, model_data + offset, Byte16);
|
||||
offset += Byte16;
|
||||
if (offset + sizeof(int32_t) > data_size) {
|
||||
MS_LOG(ERROR) << "assign len is invalid.";
|
||||
return nullptr;
|
||||
}
|
||||
int_buf.assign(model_data + offset, model_data + offset + sizeof(int32_t));
|
||||
offset += int_buf.size();
|
||||
auto block_size = ByteToInt(reinterpret_cast<Byte *>(int_buf.data()), int_buf.size());
|
||||
if (block_size <= 0) {
|
||||
MS_LOG(ERROR) << "The block_size read from the cipher data must be positive, but got " << block_size;
|
||||
return nullptr;
|
||||
}
|
||||
if (offset + block_size > data_size) {
|
||||
MS_LOG(ERROR) << "assign len is invalid.";
|
||||
return nullptr;
|
||||
}
|
||||
block_buf.assign(model_data + offset, model_data + offset + block_size);
|
||||
offset += block_buf.size();
|
||||
if (!(BlockDecrypt(decrypt_block_buf.data(), &decrypt_block_len, reinterpret_cast<Byte *>(block_buf.data()),
|
||||
block_buf.size(), key, static_cast<int32_t>(key_len), dec_mode, tag))) {
|
||||
MS_LOG(ERROR) << "Failed to decrypt data, please check if dec_key or dec_mode is valid";
|
||||
return nullptr;
|
||||
}
|
||||
memcpy(decrypt_data.get() + *decrypt_len, decrypt_block_buf.data(), static_cast<size_t>(decrypt_block_len));
|
||||
|
||||
*decrypt_len += static_cast<size_t>(decrypt_block_len);
|
||||
}
|
||||
ret = loader.Close();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "loader close failed.";
|
||||
return nullptr;
|
||||
}
|
||||
return decrypt_data;
|
||||
}
|
||||
#endif
|
||||
} // namespace mindspore::lite
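// Editor's illustrative sketch, not part of this commit: a standalone walker over the same block
// layout the Decrypt() loop above expects, i.e. repeated records of
// [int32 MAGIC_NUM][16-byte AES-GCM tag][int32 block_size][block_size bytes of payload],
// where each payload in turn holds iv_len, iv, cipher_len and cipher_data (see ParseEncryptData).
#include <cstddef>
#include <cstdint>
#include <cstring>

bool CountEncryptedBlocks(const unsigned char *data, size_t size, size_t *block_count) {
  constexpr uint32_t kMagic = 0x7F3A5ED8;  // MAGIC_NUM above
  constexpr size_t kTagLen = 16;           // Byte16 above
  size_t offset = 0;
  *block_count = 0;
  while (offset < size) {
    int32_t magic = 0;
    if (offset + sizeof(magic) > size) return false;
    std::memcpy(&magic, data + offset, sizeof(magic));
    offset += sizeof(magic);
    if (static_cast<uint32_t>(magic) != kMagic) return false;  // not an encrypted block
    if (offset + kTagLen > size) return false;
    offset += kTagLen;  // skip the AES-GCM authentication tag
    int32_t block_size = 0;
    if (offset + sizeof(block_size) > size) return false;
    std::memcpy(&block_size, data + offset, sizeof(block_size));
    offset += sizeof(block_size);
    if (block_size <= 0 || offset + static_cast<size_t>(block_size) > size) return false;
    offset += static_cast<size_t>(block_size);  // skip the per-block payload
    ++*block_count;
  }
  return true;
}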
|
|
@ -0,0 +1,29 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_UTILS_DECRYPT_H_
|
||||
#define MINDSPORE_CORE_UTILS_DECRYPT_H_
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
typedef unsigned char Byte;
|
||||
namespace mindspore::lite {
|
||||
std::unique_ptr<Byte[]> Decrypt(const std::string &lib_path, size_t *decrypt_len, const Byte *model_data,
|
||||
const size_t data_size, const Byte *key, const size_t key_len,
|
||||
const std::string &dec_mode);
|
||||
} // namespace mindspore::lite
|
||||
#endif
|
|
@ -30,10 +30,13 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
int DynamicLibraryLoader::Open(const std::string &lib_path) {
|
||||
if (handler_ != nullptr) {
|
||||
return RET_ERROR;
|
||||
return RET_OK;
|
||||
}
|
||||
std::string real_path = RealPath(lib_path.c_str());
|
||||
|
||||
if (real_path.empty()) {
|
||||
MS_LOG(ERROR) << "real_path is invalid.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
#ifndef _WIN32
|
||||
#ifndef ENABLE_ARM
|
||||
handler_ = dlopen(real_path.c_str(), RTLD_LAZY | RTLD_DEEPBIND);
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/common/storage.h"
|
||||
#include <sys/stat.h>
|
||||
#ifndef _MSC_VER
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
#include "flatbuffers/flatbuffers.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/common/file_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
constexpr size_t kMaxNum1024 = 1024;
|
||||
}
|
||||
int Storage::Save(const schema::MetaGraphT &graph, const std::string &outputPath) {
|
||||
flatbuffers::FlatBufferBuilder builder(kMaxNum1024);
|
||||
auto offset = schema::MetaGraph::Pack(builder, &graph);
|
||||
builder.Finish(offset);
|
||||
schema::FinishMetaGraphBuffer(builder, offset);
|
||||
int size = builder.GetSize();
|
||||
auto content = builder.GetBufferPointer();
|
||||
if (content == nullptr) {
|
||||
MS_LOG(ERROR) << "GetBufferPointer nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::string filename = outputPath;
|
||||
if (filename.length() == 0) {
|
||||
MS_LOG(ERROR) << "Invalid output path.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (filename.substr(filename.find_last_of(".") + 1) != "ms") {
|
||||
filename = filename + ".ms";
|
||||
}
|
||||
#ifndef _MSC_VER
|
||||
if (access(filename.c_str(), F_OK) == 0) {
|
||||
chmod(filename.c_str(), S_IWUSR);
|
||||
}
|
||||
#endif
|
||||
std::ofstream output(filename, std::ofstream::binary);
|
||||
if (!output.is_open()) {
|
||||
MS_LOG(ERROR) << "Can not open output file: " << filename;
|
||||
return RET_ERROR;
|
||||
}
|
||||
output.write((const char *)content, size);
|
||||
if (output.bad()) {
|
||||
output.close();
|
||||
MS_LOG(ERROR) << "Write output file : " << filename << " failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
output.close();
|
||||
#ifndef _MSC_VER
|
||||
chmod(filename.c_str(), S_IRUSR);
|
||||
#endif
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
schema::MetaGraphT *Storage::Load(const std::string &inputPath) {
|
||||
size_t size = 0;
|
||||
std::string filename = inputPath;
|
||||
if (filename.length() == 0) {
|
||||
MS_LOG(ERROR) << "Invalid input path.";
|
||||
return nullptr;
|
||||
}
|
||||
if (filename.substr(filename.find_last_of(".") + 1) != "ms") {
|
||||
filename = filename + ".ms";
|
||||
}
|
||||
auto buf = ReadFile(filename.c_str(), &size);
|
||||
if (buf == nullptr) {
|
||||
MS_LOG(ERROR) << "the file buffer is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
flatbuffers::Verifier verify((const uint8_t *)buf, size);
|
||||
if (!schema::VerifyMetaGraphBuffer(verify)) {
|
||||
MS_LOG(ERROR) << "the buffer is invalid and fail to create meta graph";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto graphDefT = schema::UnPackMetaGraph(buf);
|
||||
return graphDefT.release();
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
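// Editor's illustrative sketch, not part of this commit: a round trip through the Storage
// helper defined above. The output path is a placeholder and graph construction is elided.
#include <memory>
#include "src/common/storage.h"

int SaveAndReload(const mindspore::schema::MetaGraphT &graph) {
  // Save() appends ".ms" when the extension is missing and write-protects the written file.
  if (mindspore::lite::Storage::Save(graph, "/tmp/model_out") != mindspore::lite::RET_OK) {
    return -1;
  }
  std::unique_ptr<mindspore::schema::MetaGraphT> reloaded(
      mindspore::lite::Storage::Load("/tmp/model_out.ms"));
  return reloaded != nullptr ? 0 : -1;
}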
|
|
@ -0,0 +1,37 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_COMMON_STORAGE_H
#define MINDSPORE_LITE_SRC_COMMON_STORAGE_H

#include <fstream>
#include <string>
#include "include/errorcode.h"
#include "flatbuffers/flatbuffers.h"
#include "schema/inner/model_generated.h"

namespace mindspore {
namespace lite {
class Storage {
 public:
  static int Save(const schema::MetaGraphT &graph, const std::string &outputPath);

  static schema::MetaGraphT *Load(const std::string &inputPath);
};
}  // namespace lite
}  // namespace mindspore

#endif  // MINDSPORE_LITE_SRC_COMMON_STORAGE_H
@ -30,16 +30,74 @@
#include "src/cxx_api/callback/callback_adapter.h"
#include "src/cxx_api/callback/callback_impl.h"
#include "src/cxx_api/model/model_impl.h"
#ifdef ENABLE_OPENSSL
#include "src/common/decrypt.h"
#include "src/common/file_utils.h"
#endif

namespace mindspore {
std::mutex g_impl_init_lock;
#ifdef ENABLE_OPENSSL
Status DecryptModel(const std::string &cropto_lib_path, const void *model_buf, size_t model_size, const Key &dec_key,
                    const std::string &dec_mode, std::unique_ptr<Byte[]> *decrypt_buffer, size_t *decrypt_len) {
  if (model_buf == nullptr) {
    MS_LOG(ERROR) << "model_buf is nullptr.";
    return kLiteError;
  }
  *decrypt_len = 0;
  *decrypt_buffer = lite::Decrypt(cropto_lib_path, decrypt_len, reinterpret_cast<const Byte *>(model_buf), model_size,
                                  dec_key.key, dec_key.len, dec_mode);
  if (*decrypt_buffer == nullptr || *decrypt_len == 0) {
    MS_LOG(ERROR) << "Decrypt buffer failed";
    return kLiteError;
  }
  return kSuccess;
}
#endif

Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
                    const std::shared_ptr<Context> &model_context, const Key &dec_key,
                    const std::vector<char> &dec_mode) {
                    const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode,
                    const std::string &cropto_lib_path) {
#ifdef ENABLE_OPENSSL
  if (impl_ == nullptr) {
    std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
    impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
    impl_ = std::make_shared<ModelImpl>();
    if (impl_ == nullptr) {
      MS_LOG(ERROR) << "Model implement is null.";
      return kLiteFileError;
    }
  }
  if (dec_key.len > 0) {
    std::unique_ptr<Byte[]> decrypt_buffer;
    size_t decrypt_len = 0;
    Status ret = DecryptModel(cropto_lib_path, model_data, data_size, dec_key, dec_mode, &decrypt_buffer, &decrypt_len);
    if (ret != kSuccess) {
      MS_LOG(ERROR) << "Decrypt model failed.";
      return ret;
    }
    ret = impl_->Build(decrypt_buffer.get(), decrypt_len, model_type, model_context);
    if (ret != kSuccess) {
      MS_LOG(ERROR) << "Build model failed.";
      return ret;
    }
  } else {
    Status ret = impl_->Build(model_data, data_size, model_type, model_context);
    if (ret != kSuccess) {
      return ret;
    }
  }
  return kSuccess;
#else
  MS_LOG(ERROR) << "The library does not support model decryption.";
  return kLiteError;
#endif
}

Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
                    const std::shared_ptr<Context> &model_context) {
  if (impl_ == nullptr) {
    std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
    impl_ = std::make_shared<ModelImpl>();
    if (impl_ == nullptr) {
      MS_LOG(ERROR) << "Model implement is null.";
      return kLiteFileError;
@ -54,11 +112,59 @@ Status Model::Build(const void *model_data, size_t data_size, ModelType model_ty
}

Status Model::Build(const std::vector<char> &model_path, ModelType model_type,
                    const std::shared_ptr<Context> &model_context, const Key &dec_key,
                    const std::vector<char> &dec_mode) {
                    const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode,
                    const std::vector<char> &cropto_lib_path) {
#ifdef ENABLE_OPENSSL
  if (impl_ == nullptr) {
    std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
    impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
    impl_ = std::make_shared<ModelImpl>();
    if (impl_ == nullptr) {
      MS_LOG(ERROR) << "Model implement is null.";
      return kLiteFileError;
    }
  }
  if (dec_key.len > 0) {
    size_t model_size;
    auto model_buf = lite::ReadFile(model_path.data(), &model_size);
    if (model_buf == nullptr) {
      MS_LOG(ERROR) << "Read model file failed";
      return kLiteError;
    }
    std::unique_ptr<Byte[]> decrypt_buffer;
    size_t decrypt_len = 0;
    Status ret = DecryptModel(CharToString(cropto_lib_path), model_buf, model_size, dec_key, dec_mode, &decrypt_buffer,
                              &decrypt_len);
    if (ret != kSuccess) {
      MS_LOG(ERROR) << "Decrypt model failed.";
      delete[] model_buf;
      return ret;
    }
    ret = impl_->Build(decrypt_buffer.get(), decrypt_len, model_type, model_context);
    if (ret != kSuccess) {
      MS_LOG(ERROR) << "Build model failed.";
      delete[] model_buf;
      return ret;
    }
    delete[] model_buf;
  } else {
    Status ret = impl_->Build(CharToString(model_path), model_type, model_context);
    if (ret != kSuccess) {
      MS_LOG(ERROR) << "Build model failed.";
      return ret;
    }
  }
  return kSuccess;
#else
  MS_LOG(ERROR) << "The library does not support model decryption.";
  return kLiteError;
#endif
}

Status Model::Build(const std::vector<char> &model_path, ModelType model_type,
                    const std::shared_ptr<Context> &model_context) {
  if (impl_ == nullptr) {
    std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
    impl_ = std::make_shared<ModelImpl>();
    if (impl_ == nullptr) {
      MS_LOG(ERROR) << "Model implement is null.";
      return kLiteFileError;
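A sketch of how a caller might use the new decrypting Build overload (editorial illustration, not part of the commit): the file paths, the helper function, and the assumption that the string-based overload forwards to the vector<char> one are all hypothetical; the 16-byte AES-GCM key is expected to come from the caller.

// Sketch only: build an encrypted .ms model through the decryption-aware overload.
#include <cstring>
#include <memory>
#include "include/api/context.h"
#include "include/api/model.h"

mindspore::Status BuildEncrypted(const unsigned char *key_bytes, size_t key_len) {
  auto context = std::make_shared<mindspore::Context>();
  context->MutableDeviceInfo().push_back(std::make_shared<mindspore::CPUDeviceInfo>());

  mindspore::Key dec_key;
  dec_key.len = key_len;  // expected to be 16 for AES-GCM
  memcpy(dec_key.key, key_bytes, key_len);

  mindspore::Model model;
  // Last two arguments: decryption mode and the path of the OpenSSL crypto library (placeholder path).
  return model.Build("encrypted_model.ms", mindspore::kMindIR, context, dec_key, "AES-GCM",
                     "/usr/lib/libcrypto.so.1.1");
}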
@ -77,7 +183,7 @@ Status Model::Build(GraphCell graph, const std::shared_ptr<Context> &model_conte
  std::stringstream err_msg;
  if (impl_ == nullptr) {
    std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
    impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
    impl_ = std::make_shared<ModelImpl>();
    if (impl_ == nullptr) {
      MS_LOG(ERROR) << "Model implement is null.";
      return kLiteFileError;
@ -258,7 +364,7 @@ Status Model::LoadConfig(const std::vector<char> &config_path) {
    return Status(kLiteFileError, "Illegal operation.");
  }

  impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
  impl_ = std::make_shared<ModelImpl>();
  if (impl_ == nullptr) {
    MS_LOG(ERROR) << "Model implement is null.";
    return Status(kLiteFileError, "Fail to load config file.");
@ -276,7 +382,7 @@ Status Model::UpdateConfig(const std::vector<char> §ion,
                           const std::pair<std::vector<char>, std::vector<char>> &config) {
  std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
  if (impl_ == nullptr) {
    impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
    impl_ = std::make_shared<ModelImpl>();
  }
  if (impl_ != nullptr) {
    return impl_->UpdateConfig(CharToString(section), {CharToString(config.first), CharToString(config.second)});
@ -388,5 +494,4 @@ float Model::GetLearningRate() {
  }
  return impl_->GetLearningRate();
}

}  // namespace mindspore
@ -26,7 +26,7 @@
#include "schema/inner/model_generated.h"
#include "src/train/train_utils.h"
#include "src/common/quant_utils.h"
#include "tools/common/meta_graph_serializer.h"
#include "src/common/storage.h"
#include "src/train/graph_fusion.h"
#include "src/train/graph_dropout.h"
#include "src/weight_decoder.h"
@ -553,7 +553,7 @@ int TrainExport::ExportInit(const std::string model_name, std::string version) {
  return RET_OK;
}

int TrainExport::SaveToFile() { return MetaGraphSerializer::Save(*meta_graph_, file_name_); }
int TrainExport::SaveToFile() { return Storage::Save(*meta_graph_, file_name_); }

bool TrainExport::IsInputTensor(const schema::TensorT &t) {
  int total_dims = std::accumulate(t.dims.begin(), t.dims.end(), 1, std::multiplies<int>());
@ -1 +1 @@
844020
848116
@ -69,6 +69,7 @@ constexpr int kNumPrintMin = 5;
constexpr const char *DELIM_COLON = ":";
constexpr const char *DELIM_COMMA = ",";
constexpr const char *DELIM_SLASH = "/";
constexpr size_t kEncMaxLen = 16;

extern const std::unordered_map<int, std::string> kTypeIdMap;
extern const std::unordered_map<mindspore::Format, std::string> kTensorFormatMap;
@ -139,6 +140,11 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
    AddFlag(&BenchmarkFlags::cosine_distance_threshold_, "cosineDistanceThreshold", "cosine distance threshold", -1.1);
    AddFlag(&BenchmarkFlags::resize_dims_in_, "inputShapes",
            "Shape of input data, the format should be NHWC. e.g. 1,32,32,32:1,1,32,32,1", "");
    AddFlag(&BenchmarkFlags::decrypt_key_str_, "decryptKey",
            "The key used to decrypt the file, expressed in hexadecimal characters. Only AES-GCM is supported and "
            "the key length is 16.",
            "");
    AddFlag(&BenchmarkFlags::crypto_lib_path_, "cryptoLibPath", "Pass the crypto library path.", "");
    AddFlag(&BenchmarkFlags::enable_parallel_predict_, "enableParallelPredict", "Enable model parallel : true | false",
            false);
    AddFlag(&BenchmarkFlags::parallel_request_num_, "parallelRequestNum", "parallel request num of parallel predict",
@ -192,6 +198,9 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
  std::string perf_event_ = "CYCLE";
  bool dump_tensor_data_ = false;
  bool print_tensor_data_ = false;
  std::string decrypt_key_str_;
  std::string dec_mode_ = "AES-GCM";
  std::string crypto_lib_path_;
};

class MS_API BenchmarkBase {
@ -698,7 +698,6 @@ int BenchmarkUnifiedApi::CompareDataGetTotalBiasAndSize(const std::string &name,
  *total_size += 1;
  return RET_OK;
}

int BenchmarkUnifiedApi::CompareDataGetTotalCosineDistanceAndSize(const std::string &name, mindspore::MSTensor *tensor,
                                                                  float *total_cosine_distance, int *total_size) {
  if (tensor == nullptr) {
@ -1044,6 +1043,33 @@ int BenchmarkUnifiedApi::RunModelPool(std::shared_ptr<mindspore::Context> contex
}
#endif

int BenchmarkUnifiedApi::CompileGraph(ModelType model_type, const std::shared_ptr<Context> &context,
                                      const std::string &model_name) {
  Key dec_key;
  if (!flags_->decrypt_key_str_.empty()) {
    dec_key.len = lite::Hex2ByteArray(flags_->decrypt_key_str_, dec_key.key, kEncMaxLen);
    if (dec_key.len == 0) {
      MS_LOG(ERROR) << "dec_key.len == 0";
      return RET_INPUT_PARAM_INVALID;
    }
    flags_->decrypt_key_str_.clear();
  }
  Status ret;
  if (flags_->crypto_lib_path_.empty()) {
    ret = ms_model_.Build(flags_->model_file_, model_type, context);
  } else {
    ret =
      ms_model_.Build(flags_->model_file_, model_type, context, dec_key, flags_->dec_mode_, flags_->crypto_lib_path_);
  }
  memset(dec_key.key, 0, kEncMaxLen);
  if (ret != kSuccess) {
    MS_LOG(ERROR) << "ms_model_.Build failed while running " << model_name.c_str();
    std::cout << "ms_model_.Build failed while running " << model_name.c_str();
    return RET_ERROR;
  }
  return RET_OK;
}

int BenchmarkUnifiedApi::RunBenchmark() {
  auto start_prepare_time = GetTimeUs();
@ -1098,19 +1124,17 @@ int BenchmarkUnifiedApi::RunBenchmark() {
  }
#endif

  auto ret = ms_model_.Build(flags_->model_file_, model_type, context);
  if (ret != kSuccess) {
    MS_LOG(ERROR) << "ms_model_.Build failed while running ", model_name.c_str();
    std::cout << "ms_model_.Build failed while running ", model_name.c_str();
    return RET_ERROR;
  status = CompileGraph(model_type, context, model_name);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Compile graph failed.";
    return status;
  }

  if (!flags_->resize_dims_.empty()) {
    std::vector<std::vector<int64_t>> resize_dims;
    (void)std::transform(flags_->resize_dims_.begin(), flags_->resize_dims_.end(), std::back_inserter(resize_dims),
                         [&](auto &shapes) { return this->ConverterToInt64Vector<int>(shapes); });

    ret = ms_model_.Resize(ms_model_.GetInputs(), resize_dims);
    auto ret = ms_model_.Resize(ms_model_.GetInputs(), resize_dims);
    if (ret != kSuccess) {
      MS_LOG(ERROR) << "Input tensor resize failed.";
      std::cout << "Input tensor resize failed.";
@ -62,6 +62,8 @@ class MS_API BenchmarkUnifiedApi : public BenchmarkBase {
                                              float *total_cosine_distance, int *total_size);
  void InitContext(const std::shared_ptr<mindspore::Context> &context);

  int CompileGraph(ModelType model_type, const std::shared_ptr<Context> &context, const std::string &model_name);

#ifdef ENABLE_OPENGL_TEXTURE
  int GenerateGLTexture(std::map<std::string, GLuint> *inputGlTexture);
@ -206,7 +206,8 @@ bool MetaGraphSerializer::ExtraAndSerializeModelWeight(const schema::MetaGraphT
  return true;
}

bool MetaGraphSerializer::SerializeModelAndUpdateWeight(const schema::MetaGraphT &meta_graphT) {
bool MetaGraphSerializer::SerializeModelAndUpdateWeight(const schema::MetaGraphT &meta_graphT, const Byte *key,
                                                        const size_t key_len, const std::string &enc_mode) {
  // serialize model
  flatbuffers::FlatBufferBuilder builder(kFlatbuffersBuilderInitSize);
  auto offset = schema::MetaGraph::Pack(builder, &meta_graphT);
@ -214,7 +215,7 @@ bool MetaGraphSerializer::SerializeModelAndUpdateWeight(const schema::MetaGraphT
  schema::FinishMetaGraphBuffer(builder, offset);
  size_t size = builder.GetSize();
  auto content = builder.GetBufferPointer();
  if (!SerializeModel(content, size)) {
  if (!SerializeModel(content, size, key, key_len, enc_mode)) {
    MS_LOG(ERROR) << "Serialize graph failed";
    return false;
  }
@ -238,7 +239,8 @@ bool MetaGraphSerializer::SerializeModelAndUpdateWeight(const schema::MetaGraphT
  return true;
}

int MetaGraphSerializer::Save(const schema::MetaGraphT &graph, const std::string &output_path) {
int MetaGraphSerializer::Save(const schema::MetaGraphT &graph, const std::string &output_path, const Byte *key,
                              const size_t key_len, const std::string &enc_mode) {
  flatbuffers::FlatBufferBuilder builder(kFlatbuffersBuilderInitSize);
  auto offset = schema::MetaGraph::Pack(builder, &graph);
  builder.Finish(offset);
@ -255,7 +257,7 @@ int MetaGraphSerializer::Save(const schema::MetaGraphT &graph, const std::string
    return RET_ERROR;
  }
  if (save_together) {
    if (!meta_graph_serializer.SerializeModel(builder.GetBufferPointer(), size)) {
    if (!meta_graph_serializer.SerializeModel(builder.GetBufferPointer(), size, key, key_len, enc_mode)) {
      MS_LOG(ERROR) << "Serialize graph failed";
      return RET_ERROR;
    }
@ -264,7 +266,7 @@ int MetaGraphSerializer::Save(const schema::MetaGraphT &graph, const std::string
      MS_LOG(ERROR) << "Serialize graph weight failed";
      return RET_ERROR;
    }
    if (!meta_graph_serializer.SerializeModelAndUpdateWeight(graph)) {
    if (!meta_graph_serializer.SerializeModelAndUpdateWeight(graph, key, key_len, enc_mode)) {
      MS_LOG(ERROR) << "Serialize graph and adjust weight failed";
      return RET_ERROR;
    }
@ -283,14 +285,25 @@ MetaGraphSerializer::~MetaGraphSerializer() {
  }
}

bool MetaGraphSerializer::SerializeModel(const void *content, size_t size) {
bool MetaGraphSerializer::SerializeModel(const void *content, size_t size, const Byte *key, const size_t key_len,
                                         const std::string &enc_mode) {
  MS_ASSERT(model_fs_ != nullptr);
  if (size == 0 || content == nullptr) {
    MS_LOG(ERROR) << "Input meta graph buffer is nullptr";
    return false;
  }

  model_fs_->write((const char *)content, static_cast<int64_t>(size));
  if (key_len > 0) {
    size_t encrypt_len;
    auto encrypt_content = Encrypt(&encrypt_len, reinterpret_cast<const Byte *>(content), size, key, key_len, enc_mode);
    if (encrypt_content == nullptr || encrypt_len == 0) {
      MS_LOG(ERROR) << "Encrypt failed.";
      model_fs_->close();
      return false;
    }
    model_fs_->write(reinterpret_cast<const char *>(encrypt_content.get()), encrypt_len);
  } else {
    model_fs_->write((const char *)content, static_cast<int64_t>(size));
  }
  if (model_fs_->bad()) {
    MS_LOG(ERROR) << "Write model file failed: " << save_model_path_;
    return false;
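A sketch of how the encrypting serializer above might be driven (editorial illustration, not part of the commit): the wrapper function, the output path, and the hard-coded key bytes are placeholders; mindspore::Byte is assumed to be the alias from utils/crypto.h.

// Sketch only: export a MetaGraphT with encryption; with key_len == 0 the plain buffer is written instead.
#include "tools/common/meta_graph_serializer.h"

int ExportEncrypted(const mindspore::schema::MetaGraphT &graph) {
  const mindspore::Byte key[16] = {0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef,
                                   0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef};
  // The serializer encrypts the serialized flatbuffer with Encrypt() before writing it, as shown above.
  return mindspore::lite::MetaGraphSerializer::Save(graph, "encrypted_model", key, sizeof(key), "AES-GCM");
}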
@ -21,12 +21,14 @@
#include <string>
#include "flatbuffers/flatbuffers.h"
#include "schema/inner/model_generated.h"
#include "utils/crypto.h"

namespace mindspore::lite {
class MetaGraphSerializer {
 public:
  // save serialized fb model
  static int Save(const schema::MetaGraphT &graph, const std::string &output_path);
  static int Save(const schema::MetaGraphT &graph, const std::string &output_path, const Byte *key = {},
                  const size_t key_len = 0, const std::string &enc_mode = "");

 private:
  MetaGraphSerializer() = default;
@ -41,9 +43,11 @@ class MetaGraphSerializer {
  bool ExtraAndSerializeModelWeight(const schema::MetaGraphT &graph);

  bool SerializeModelAndUpdateWeight(const schema::MetaGraphT &meta_graphT);
  bool SerializeModelAndUpdateWeight(const schema::MetaGraphT &meta_graphT, const Byte *key, const size_t key_len,
                                     const std::string &enc_mode);

  bool SerializeModel(const void *content, size_t size);
  bool SerializeModel(const void *content, size_t size, const Byte *key, const size_t key_len,
                      const std::string &enc_mode);

 private:
  int64_t cur_offset_ = 0;
@ -18,6 +18,7 @@
#include <algorithm>
#include <vector>
#include <string>
#include <regex>

namespace mindspore {
namespace lite {
@ -126,5 +127,42 @@ bool ConvertDoubleVector(const std::string &str, std::vector<double> *value) {
  }
  return true;
}

size_t Hex2ByteArray(const std::string &hex_str, unsigned char *byte_array, size_t max_len) {
  std::regex r("[0-9a-fA-F]+");
  if (!std::regex_match(hex_str, r)) {
    MS_LOG(ERROR) << "Some characters of dec_key are not in [0-9a-fA-F]";
    return 0;
  }
  if (hex_str.size() % 2 == 1) {  // Mod 2 determines whether the length is odd
    MS_LOG(ERROR) << "the hexadecimal dec_key length must be even";
    return 0;
  }
  size_t byte_len = hex_str.size() / 2;  // Two hexadecimal characters represent one byte
  if (byte_len > max_len) {
    MS_LOG(ERROR) << "the hexadecimal dec_key length exceeds the maximum limit: " << max_len;
    return 0;
  }
  constexpr int32_t a_val = 10;  // The value of 'A' in hexadecimal is 10
  constexpr size_t half_byte_offset = 4;
  for (size_t i = 0; i < byte_len; ++i) {
    size_t p = i * 2;  // The i-th byte is represented by the 2*i-th and (2*i+1)-th hexadecimal characters
    if (hex_str[p] >= 'a' && hex_str[p] <= 'f') {
      byte_array[i] = hex_str[p] - 'a' + a_val;
    } else if (hex_str[p] >= 'A' && hex_str[p] <= 'F') {
      byte_array[i] = hex_str[p] - 'A' + a_val;
    } else {
      byte_array[i] = hex_str[p] - '0';
    }
    if (hex_str[p + 1] >= 'a' && hex_str[p + 1] <= 'f') {
      byte_array[i] = (byte_array[i] << half_byte_offset) | (hex_str[p + 1] - 'a' + a_val);
    } else if (hex_str[p + 1] >= 'A' && hex_str[p + 1] <= 'F') {
      byte_array[i] = (byte_array[i] << half_byte_offset) | (hex_str[p + 1] - 'A' + a_val);
    } else {
      byte_array[i] = (byte_array[i] << half_byte_offset) | (hex_str[p + 1] - '0');
    }
  }
  return byte_len;
}
}  // namespace lite
}  // namespace mindspore
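A sketch of how the benchmark and converter front ends use Hex2ByteArray to turn a hexadecimal key flag into raw bytes (editorial illustration, not part of the commit): the wrapper function is hypothetical and the header path is assumed from the include guard shown below.

// Sketch only: parse a --decryptKey / --encryptKey style flag into a 16-byte AES-GCM key.
#include <string>
#include "tools/common/string_util.h"

bool ParseHexKey(const std::string &hex_key, unsigned char *key, size_t *key_len) {
  constexpr size_t kEncMaxLen = 16;  // the AES-GCM key length used throughout this commit
  // A 32-character string such as "0123456789ABCDEF0123456789ABCDEF" yields the full 16 bytes.
  *key_len = mindspore::lite::Hex2ByteArray(hex_key, key, kEncMaxLen);
  return *key_len == kEncMaxLen;
}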
@ -40,6 +40,8 @@ bool ConvertDoubleNum(const std::string &str, double *value);
bool ConvertBool(std::string str, bool *value);

bool ConvertDoubleVector(const std::string &str, std::vector<double> *value);

size_t Hex2ByteArray(const std::string &hex_str, unsigned char *byte_array, size_t max_len);
}  // namespace lite
}  // namespace mindspore
#endif  // MINDSPORE_LITE_SRC_TOOLS_STRING_UTIL_H_
@ -343,4 +343,32 @@ int GenerateRandomData(mindspore::tensor::MSTensor *tensor) {
  }
  return RET_OK;
}

int GenerateRandomData(mindspore::MSTensor *tensor) {
  MS_ASSERT(tensor != nullptr);
  auto input_data = tensor->MutableData();
  if (input_data == nullptr) {
    MS_LOG(ERROR) << "MallocData for inTensor failed";
    return RET_ERROR;
  }
  int status = RET_ERROR;
  if (static_cast<TypeId>(tensor->DataType()) == kObjectTypeString) {
    MSTensor *input = MSTensor::StringsToTensor(tensor->Name(), {"you're the best."});
    if (input == nullptr) {
      std::cerr << "StringsToTensor failed" << std::endl;
      MS_LOG(ERROR) << "StringsToTensor failed";
      return RET_ERROR;
    }
    *tensor = *input;
    delete input;
    status = RET_OK;  // the string tensor has been filled with the sample value above
  } else {
    status = GenerateRandomData(tensor->DataSize(), input_data, static_cast<int>(tensor->DataType()));
  }
  if (status != RET_OK) {
    std::cerr << "GenerateRandomData for inTensor failed: " << status << std::endl;
    MS_LOG(ERROR) << "GenerateRandomData for inTensor failed: " << status;
    return status;
  }
  return RET_OK;
}
}  // namespace mindspore::lite
@ -78,6 +78,8 @@ std::unique_ptr<schema::QuantParamT> CopyQuantParamT(const std::unique_ptr<schem
int GenerateRandomData(mindspore::tensor::MSTensor *tensors);

int GenerateRandomData(mindspore::MSTensor *tensors);

int GenerateRandomData(size_t size, void *data, int data_type);

template <typename T, typename Distribution>
@ -1,7 +1,9 @@
add_definitions(-DPRIMITIVE_WRITEABLE)
add_definitions(-DUSE_GLOG)
set(USE_GLOG on)

if(MSLITE_ENABLE_MODEL_ENCRYPTION)
    add_compile_definitions(ENABLE_OPENSSL)
endif()
set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src)

set(CCSRC_SRC
@ -13,8 +15,8 @@ set(CCSRC_SRC
include_directories(${TOP_DIR}/mindspore/ccsrc/plugin/device/cpu/kernel)

if(NOT WIN32 AND NOT MSLITE_ENABLE_ACL)
    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -rdynamic -fvisibility=hidden")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -rdynamic -fvisibility=hidden")
    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -rdynamic -fvisibility=hidden")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -rdynamic -fvisibility=hidden")
endif()

file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
@ -50,11 +52,10 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
        ${CMAKE_CURRENT_SOURCE_DIR}/import/mindir_adjust.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/import/mindir_control_flow_adjust.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/adapter/acl/acl_pass.cc

        ${SRC_DIR}/common/quant_utils.cc
        ${SRC_DIR}/common/dynamic_library_loader.cc
        ${SRC_DIR}/train/train_populate_parameter.cc

        ${SRC_DIR}/common/config_file.cc
        ../optimizer/*.cc
        )
@ -76,16 +77,20 @@ add_subdirectory(micro/coder)

if(MSLITE_ENABLE_ACL)
    set(MODE_ASCEND_ACL ON)
    include_directories(${TOP_DIR}/graphengine/inc/external)
    include(${TOP_DIR}/cmake/dependency_graphengine.cmake)
    add_subdirectory(adapter/acl)
    link_directories(${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
endif()

set(API_SRC ${SRC_DIR}/cxx_api/context.cc)
if(MSLITE_ENABLE_ACL)
    list(APPEND API_SRC ${SRC_DIR}/cxx_api/kernel.cc)
endif()
file(GLOB CXX_API_SRCS
        ${SRC_DIR}/cxx_api/*.cc
        ${SRC_DIR}/cxx_api/model/*.cc
        ${SRC_DIR}/cxx_api/graph/*.cc
        ${SRC_DIR}/cxx_api/tensor/*.cc)

set(LITE_SRC ${API_SRC}
        ${CXX_API_SRCS}
        ${SRC_DIR}/ops/ops_def.cc
        ${SRC_DIR}/ops/ops_utils.cc
        ${SRC_DIR}/common/utils.cc
@ -97,6 +102,7 @@ set(LITE_SRC ${API_SRC}
        ${SRC_DIR}/common/log.cc
        ${SRC_DIR}/common/prim_util.cc
        ${SRC_DIR}/common/tensor_util.cc
        ${SRC_DIR}/common/decrypt.cc
        ${SRC_DIR}/runtime/allocator.cc
        ${SRC_DIR}/runtime/inner_allocator.cc
        ${SRC_DIR}/runtime/runtime_allocator.cc
@ -33,9 +33,15 @@
#include "tools/converter/import/mindspore_importer.h"
#include "nnacl/op_base.h"
#include "tools/converter/micro/coder/coder.h"
#include "src/common/prim_util.h"
#include "src/common/version_manager.h"
#include "tools/common/tensor_util.h"
#include "include/api/model.h"

namespace mindspore {
namespace lite {
namespace {
constexpr size_t kMaxNum1024 = 1024;
void InitConverterParameters(const converter::Flags &flag, converter::ConverterParameters *converter_parameters) {
  MS_ASSERT(converter_parameters != nullptr);
  converter_parameters->fmk = flag.fmk;
@ -178,6 +184,90 @@ schema::MetaGraphT *Converter::TransferFuncGraph(const std::unique_ptr<converter
  return meta_graph;
}

int CheckExistCustomOps(const schema::MetaGraphT *meta_graph, bool *exist_custom_nodes) {
  MS_CHECK_TRUE_MSG(meta_graph != nullptr && exist_custom_nodes != nullptr, RET_ERROR, "input params contain nullptr.");
  flatbuffers::FlatBufferBuilder fbb(kMaxNum1024);
  for (const auto &node : meta_graph->nodes) {
    auto prim = ConvertToPrimitive(node->primitive.get(), &fbb);
    if (prim == nullptr) {
      MS_LOG(ERROR) << "get primitive failed.";
      fbb.Clear();
      return RET_ERROR;
    }
    if (IsCustomNode(prim, static_cast<int>(SCHEMA_CUR))) {
      *exist_custom_nodes = true;
      break;
    }
  }
  fbb.Clear();
  return RET_OK;
}

int PreInference(const schema::MetaGraphT &meta_graph, const std::unique_ptr<converter::Flags> &flags) {
  if (flags->trainModel) {
    MS_LOG(WARNING) << "train model does not support pre-inference.";
    return RET_OK;
  }

  bool exist_custom_nodes = false;
  auto check_ret = CheckExistCustomOps(&meta_graph, &exist_custom_nodes);
  if (check_ret == RET_ERROR) {
    MS_LOG(ERROR) << "CheckExistCustomOps failed.";
    return RET_ERROR;
  }
  if (exist_custom_nodes) {
    MS_LOG(WARNING) << "custom nodes exist, skip pre-inference.";
    return RET_OK;
  }
  mindspore::Model model;
  flatbuffers::FlatBufferBuilder builder(kMaxNum1024);
  auto offset = schema::MetaGraph::Pack(builder, &meta_graph);
  builder.Finish(offset);
  schema::FinishMetaGraphBuffer(builder, offset);
  int size = builder.GetSize();
  auto content = builder.GetBufferPointer();
  if (content == nullptr) {
    MS_LOG(ERROR) << "GetBufferPointer nullptr";
    return RET_ERROR;
  }
  auto context = std::make_shared<mindspore::Context>();
  if (context == nullptr) {
    MS_LOG(ERROR) << "New context failed while running ";
    std::cerr << "New context failed while running " << std::endl;
    return RET_ERROR;
  }
  std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>();
  auto &device_list = context->MutableDeviceInfo();
  device_list.push_back(device_info);

  auto ret = model.Build(content, size, kMindIR, context);
  if (ret != kSuccess) {
    MS_LOG(ERROR) << "Build error ";
    std::cerr << "Build error " << std::endl;
    return RET_ERROR;
  }
  for (auto &tensor : model.GetInputs()) {
    if (tensor.Shape().empty() || tensor.DataSize() <= 0 ||
        std::find(tensor.Shape().begin(), tensor.Shape().end(), -1) != tensor.Shape().end()) {
      MS_LOG(WARNING) << tensor.Name() << " is dynamic shape and will not be pre-inferred.";
      return RET_OK;
    }
    auto status = GenerateRandomData(&tensor);
    if (status != RET_OK) {
      MS_LOG(ERROR) << tensor.Name() << " GenerateRandomData failed.";
      return status;
    }
  }
  std::vector<MSTensor> outputs;
  ret = model.Predict(model.GetInputs(), &outputs);
  if (ret != kSuccess) {
    MS_LOG(ERROR) << "Inference error ";
    std::cerr << "Inference error " << std::endl;
    return RET_ERROR;
  }
  return RET_OK;
}

int RunConverter(int argc, const char **argv) {
  std::ostringstream oss;
  auto flags = std::make_unique<converter::Flags>();
@ -215,6 +305,18 @@ int RunConverter(int argc, const char **argv) {
  // save graph to file
  meta_graph->version = Version();

  if (flags->infer) {
    status = PreInference(*meta_graph, flags);
    if (status != RET_OK) {
      oss.clear();
      oss << "PRE INFERENCE FAILED:" << status << " " << GetErrorInfo(status);
      MS_LOG(ERROR) << oss.str();
      std::cout << oss.str() << std::endl;
      delete meta_graph;
      return status;
    }
  }

  if (flags->microParam.enable_micro) {
    status = micro::Coder::MicroSourceCodeGeneration(*meta_graph, flags->outputFile, flags->microParam.codegen_mode,
                                                     flags->microParam.target, flags->microParam.support_parallel,
@ -228,7 +330,7 @@ int RunConverter(int argc, const char **argv) {
      return status;
    }
  } else {
    status = MetaGraphSerializer::Save(*meta_graph, flags->outputFile);
    status = MetaGraphSerializer::Save(*meta_graph, flags->outputFile, flags->encKey, flags->keyLen, flags->encMode);
    if (status != RET_OK) {
      delete meta_graph;
      oss.clear();
@ -238,7 +340,12 @@ int RunConverter(int argc, const char **argv) {
      return status;
    }
  }

  // clear key
  status = memset_s(flags->encKey, converter::kEncMaxLen, 0, converter::kEncMaxLen);
  if (status != EOK) {
    MS_LOG(ERROR) << "memset failed.";
    return RET_ERROR;
  }
  delete meta_graph;
  oss.clear();
  oss << "CONVERT RESULT SUCCESS:" << status;
@ -83,6 +83,20 @@ Flags::Flags() {
          "");
  AddFlag(&Flags::graphInputFormatStr, "inputDataFormat",
          "Assign the input format of exported model. Only valid for 4-dimensional input. NHWC | NCHW", "NHWC");
#ifdef ENABLE_OPENSSL
  AddFlag(&Flags::encryptionStr, "encryption",
          "Whether to export the encrypted model. "
          "true | false",
          "true");
  AddFlag(&Flags::encKeyStr, "encryptKey",
          "The key used to encrypt the file, expressed in hexadecimal characters. Only AES-GCM is supported and "
          "the key length is 16.",
          "");
#endif
  AddFlag(&Flags::inferStr, "infer",
          "Whether to do pre-inference after convert. "
          "true | false",
          "false");
}

int Flags::InitInputOutputDataType() {
@ -310,8 +324,56 @@ int Flags::InitConfigFile() {
  return RET_OK;
}

int Flags::Init(int argc, const char **argv) {
  int ret;
int Flags::InitSaveFP16() {
  if (saveFP16Str == "on") {
    saveFP16 = true;
  } else if (saveFP16Str == "off") {
    saveFP16 = false;
  } else {
    std::cerr << "Init save_fp16 failed." << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }
  return RET_OK;
}

int Flags::InitPreInference() {
  if (this->inferStr == "true") {
    this->infer = true;
  } else if (this->inferStr == "false") {
    this->infer = false;
  } else {
    std::cerr << "INPUT ILLEGAL: infer must be true|false " << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }
  return RET_OK;
}

int Flags::InitEncrypt() {
  if (this->encryptionStr == "true") {
    this->encryption = true;
  } else if (this->encryptionStr == "false") {
    this->encryption = false;
  } else {
    std::cerr << "INPUT ILLEGAL: encryption must be true|false " << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }
  if (this->encryption) {
    if (encKeyStr.empty()) {
      MS_LOG(ERROR) << "encryption is enabled but encryptKey is empty. If model encryption is not needed, "
                       "please set --encryption=false.";
      return RET_INPUT_PARAM_INVALID;
    }
    keyLen = lite::Hex2ByteArray(encKeyStr, encKey, kEncMaxLen);
    if (keyLen != kEncMaxLen) {
      MS_LOG(ERROR) << "enc_key " << encKeyStr << " must be expressed in hexadecimal characters;"
                    << " only AES-GCM is supported and the key length is 16.";
      return RET_INPUT_PARAM_INVALID;
    }
    encKeyStr.clear();
  }
  return RET_OK;
}

int Flags::PreInit(int argc, const char **argv) {
  if (argc == 1) {
    std::cout << this->Usage() << std::endl;
    return lite::RET_SUCCESS_EXIT;
@ -353,19 +415,23 @@ int Flags::Init(int argc, const char **argv) {
  }

  if (!this->configFile.empty()) {
    ret = InitConfigFile();
    auto ret = InitConfigFile();
    if (ret != RET_OK) {
      std::cerr << "Init config file failed." << std::endl;
      return RET_INPUT_PARAM_INVALID;
    }
  }
  return RET_OK;
}

  if (saveFP16Str == "on") {
    saveFP16 = true;
  } else if (saveFP16Str == "off") {
    saveFP16 = false;
  } else {
    std::cerr << "Init save_fp16 failed." << std::endl;
int Flags::Init(int argc, const char **argv) {
  auto ret = PreInit(argc, argv);
  if (ret != RET_OK) {
    return ret;
  }
  ret = InitSaveFP16();
  if (ret != RET_OK) {
    std::cerr << "Init save fp16 failed." << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }
@ -398,8 +464,25 @@ int Flags::Init(int argc, const char **argv) {
    std::cerr << "Init graph input format failed." << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }

  ret = InitEncrypt();
  if (ret != RET_OK) {
    std::cerr << "Init encrypt failed." << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }

  ret = InitPreInference();
  if (ret != RET_OK) {
    std::cerr << "Init pre inference failed." << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }
  return RET_OK;
}
Flags::~Flags() {
  dec_key.clear();
  encKeyStr.clear();
  memset(encKey, 0, kEncMaxLen);
}

bool CheckOfflineParallelConfig(const std::string &file, ParallelSplitConfig *parallel_split_config) {
  // device: [device0 device1] ---> {cpu, gpu}
@ -40,6 +40,7 @@ constexpr auto kMaxSplitRatio = 10;
constexpr auto kComputeRate = "computeRate";
constexpr auto kSplitDevice0 = "device0";
constexpr auto kSplitDevice1 = "device1";
constexpr size_t kEncMaxLen = 16;
struct ParallelSplitConfig {
  ParallelSplitType parallel_split_type_ = SplitNo;
  std::vector<int64_t> parallel_compute_rates_;
@ -50,7 +51,7 @@ class Flags : public virtual mindspore::lite::FlagParser {
 public:
  Flags();

  ~Flags() override = default;
  ~Flags() override;

  int InitInputOutputDataType();

@ -66,8 +67,16 @@ class Flags : public virtual mindspore::lite::FlagParser {
  int InitExtendedIntegrationInfo(const lite::ConfigFileParser &config_file_parser);

  int InitEncrypt();

  int InitPreInference();

  int InitSaveFP16();

  int Init(int argc, const char **argv);

  int PreInit(int argc, const char **argv);

  std::string modelFile;
  std::string outputFile;
  std::string fmkIn;
@ -91,7 +100,19 @@ class Flags : public virtual mindspore::lite::FlagParser {
  std::string graphInputFormatStr;
  std::string device;
  mindspore::Format graphInputFormat = mindspore::NHWC;
  bool enable_micro = false;
  std::string encKeyStr;
  std::string encMode = "AES-GCM";
  std::string inferStr;
#ifdef ENABLE_OPENSSL
  std::string encryptionStr = "true";
  bool encryption = true;
#else
  std::string encryptionStr = "false";
  bool encryption = false;
#endif
  bool infer = false;
  unsigned char encKey[kEncMaxLen];
  size_t keyLen = 0;

  lite::quant::CommonQuantParam commonQuantParam;
  lite::quant::MixedBitWeightQuantParam mixedBitWeightQuantParam;
Binary file not shown.