diff --git a/cmake/external_libs/openssl.cmake b/cmake/external_libs/openssl.cmake index 4bd580b0f88..9524b743c97 100644 --- a/cmake/external_libs/openssl.cmake +++ b/cmake/external_libs/openssl.cmake @@ -12,17 +12,66 @@ else() set(OPENSSL_PATCH_ROOT ${CMAKE_SOURCE_DIR}/third_party/patch/openssl) endif() -if(${CMAKE_SYSTEM_NAME} MATCHES "Linux" OR APPLE) - mindspore_add_pkg(openssl - VER 1.1.1k - LIBS ssl crypto - URL ${REQ_URL} - MD5 ${MD5} - CONFIGURE_COMMAND ./config no-zlib no-shared - PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3711.patch - PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3712.patch - ) - include_directories(${openssl_INC}) - add_library(mindspore::ssl ALIAS openssl::ssl) - add_library(mindspore::crypto ALIAS openssl::crypto) +if(BUILD_LITE) + if(PLATFORM_ARM64 AND ANDROID_NDK_TOOLCHAIN_INCLUDED) + set(ANDROID_NDK_ROOT $ENV{ANDROID_NDK}) + set(PATH + ${ANDROID_NDK_ROOT}/toolchains/llvm/prebuilt/linux-x86_64/bin: + ${ANDROID_NDK_ROOT}/toolchains/arm-linux-androideabi-4.9/prebuilt/linux-x86_64/bin: + $ENV{PATH}) + mindspore_add_pkg(openssl + VER 1.1.1k + LIBS ssl crypto + URL ${REQ_URL} + MD5 ${MD5} + CONFIGURE_COMMAND ./Configure android-arm64 -D__ANDROID_API__=29 no-zlib + PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3711.patch + PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3712.patch + ) + elseif(PLATFORM_ARM32 AND ANDROID_NDK_TOOLCHAIN_INCLUDED) + set(ANDROID_NDK_ROOT $ENV{ANDROID_NDK}) + set(PATH + ${ANDROID_NDK_ROOT}/toolchains/llvm/prebuilt/linux-x86_64/bin: + ${ANDROID_NDK_ROOT}/toolchains/arm-linux-androideabi-4.9/prebuilt/linux-x86_64/bin: + $ENV{PATH}) + mindspore_add_pkg(openssl + VER 1.1.1k + LIBS ssl crypto + URL ${REQ_URL} + MD5 ${MD5} + CONFIGURE_COMMAND ./Configure android-arm -D__ANDROID_API__=29 no-zlib + PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3711.patch + PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3712.patch + ) + elseif(${CMAKE_SYSTEM_NAME} MATCHES "Linux" OR APPLE) + mindspore_add_pkg(openssl + VER 1.1.1k + LIBS ssl crypto + URL ${REQ_URL} + MD5 ${MD5} + CONFIGURE_COMMAND ./config no-zlib no-shared + PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3711.patch + PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3712.patch + ) + else() + MESSAGE(FATAL_ERROR "openssl does not support compilation for the current environment.") + endif() + include_directories(${openssl_INC}) + add_library(mindspore::ssl ALIAS openssl::ssl) + add_library(mindspore::crypto ALIAS openssl::crypto) +else() + if(${CMAKE_SYSTEM_NAME} MATCHES "Linux" OR APPLE) + mindspore_add_pkg(openssl + VER 1.1.1k + LIBS ssl crypto + URL ${REQ_URL} + MD5 ${MD5} + CONFIGURE_COMMAND ./config no-zlib no-shared + PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3711.patch + PATCHES ${OPENSSL_PATCH_ROOT}/CVE-2021-3712.patch + ) + include_directories(${openssl_INC}) + add_library(mindspore::ssl ALIAS openssl::ssl) + add_library(mindspore::crypto ALIAS openssl::crypto) + endif() endif() diff --git a/include/api/model.h b/include/api/model.h index 7d77f74e26a..ba682550579 100644 --- a/include/api/model.h +++ b/include/api/model.h @@ -173,7 +173,7 @@ class MS_API Model { /// \return Status of operation Status UpdateFeatureMaps(const std::vector &new_weights); - /// \brief Obtains optimizer params tensors of the model. + /// \brief Obtains optimizer params tensors of the model. /// /// \return The vector that includes all params tensors. std::vector GetOptimizerParams() const; @@ -256,17 +256,14 @@ class MS_API Model { /// \brief Build a model from model buffer so that it can run on a device. Only valid for Lite. /// /// \param[in] model_data Define the buffer read from a model file. - /// \param[in] size Define bytes number of model buffer. + /// \param[in] data_size Define bytes number of model buffer. /// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only /// ModelType::kMindIR is valid for Lite. /// \param[in] model_context Define the context used to store options during execution. - /// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16, 24, or 32. - /// \param[in] dec_mode Define the decryption mode. Options: AES-GCM, AES-CBC. /// /// \return Status. - inline Status Build(const void *model_data, size_t data_size, ModelType model_type, - const std::shared_ptr &model_context = nullptr, const Key &dec_key = {}, - const std::string &dec_mode = kDecModeAesGcm); + Status Build(const void *model_data, size_t data_size, ModelType model_type, + const std::shared_ptr &model_context = nullptr); /// \brief Load and build a model from model buffer so that it can run on a device. Only valid for Lite. /// @@ -274,13 +271,40 @@ class MS_API Model { /// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only /// ModelType::kMindIR is valid for Lite. /// \param[in] model_context Define the context used to store options during execution. - /// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16, 24, or 32. - /// \param[in] dec_mode Define the decryption mode. Options: AES-GCM, AES-CBC. /// /// \return Status. - inline Status Build(const std::string &model_path, ModelType model_type, - const std::shared_ptr &model_context = nullptr, const Key &dec_key = {}, - const std::string &dec_mode = kDecModeAesGcm); + Status Build(const std::string &model_path, ModelType model_type, + const std::shared_ptr &model_context = nullptr); + + /// \brief Build a model from model buffer so that it can run on a device. Only valid for Lite. + /// + /// \param[in] model_data Define the buffer read from a model file. + /// \param[in] data_size Define bytes number of model buffer. + /// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only + /// ModelType::kMindIR is valid for Lite. + /// \param[in] model_context Define the context used to store options during execution. + /// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16. + /// \param[in] dec_mode Define the decryption mode. Options: AES-GCM. + /// \param[in] cropto_lib_path Define the openssl library path. + /// + /// \return Status. + Status Build(const void *model_data, size_t data_size, ModelType model_type, + const std::shared_ptr &model_context, const Key &dec_key, const std::string &dec_mode, + const std::string &cropto_lib_path); + + /// \brief Load and build a model from model buffer so that it can run on a device. Only valid for Lite. + /// + /// \param[in] model_path Define the model path. + /// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kOM. Only + /// ModelType::kMindIR is valid for Lite. + /// \param[in] model_context Define the context used to store options during execution. + /// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16. + /// \param[in] dec_mode Define the decryption mode. Options: AES-GCM. + /// \param[in] cropto_lib_path Define the openssl library path. + /// + /// \return Status. + Status Build(const std::string &model_path, ModelType model_type, const std::shared_ptr &model_context, + const Key &dec_key, const std::string &dec_mode, const std::string &cropto_lib_path); private: friend class Serialization; @@ -291,11 +315,10 @@ class MS_API Model { std::vector GetOutputsByNodeName(const std::vector &node_name); Status LoadConfig(const std::vector &config_path); Status UpdateConfig(const std::vector §ion, const std::pair, std::vector> &config); - Status Build(const void *model_data, size_t data_size, ModelType model_type, - const std::shared_ptr &model_context, const Key &dec_key, const std::vector &dec_mode); + Status Build(const std::vector &model_path, ModelType model_type, + const std::shared_ptr &model_context); Status Build(const std::vector &model_path, ModelType model_type, const std::shared_ptr &model_context, - const Key &dec_key, const std::vector &dec_mode); - + const Key &dec_key, const std::string &dec_mode, const std::vector &cropto_lib_path); std::shared_ptr impl_; }; @@ -321,14 +344,15 @@ Status Model::UpdateConfig(const std::string §ion, const std::pair &model_context, const Key &dec_key, const std::string &dec_mode) { - return Build(model_data, data_size, model_type, model_context, dec_key, StringToChar(dec_mode)); +inline Status Model::Build(const std::string &model_path, ModelType model_type, + const std::shared_ptr &model_context, const Key &dec_key, + const std::string &dec_mode, const std::string &cropto_lib_path) { + return Build(StringToChar(model_path), model_type, model_context, dec_key, dec_mode, StringToChar(cropto_lib_path)); } -Status Model::Build(const std::string &model_path, ModelType model_type, const std::shared_ptr &model_context, - const Key &dec_key, const std::string &dec_mode) { - return Build(StringToChar(model_path), model_type, model_context, dec_key, StringToChar(dec_mode)); +inline Status Model::Build(const std::string &model_path, ModelType model_type, + const std::shared_ptr &model_context) { + return Build(StringToChar(model_path), model_type, model_context); } } // namespace mindspore #endif // MINDSPORE_INCLUDE_API_MODEL_H diff --git a/mindspore/ccsrc/cxx_api/model/model.cc b/mindspore/ccsrc/cxx_api/model/model.cc index 0c3c03c77f3..39d9fd3a6c6 100644 --- a/mindspore/ccsrc/cxx_api/model/model.cc +++ b/mindspore/ccsrc/cxx_api/model/model.cc @@ -52,14 +52,13 @@ Status Model::Build(GraphCell graph_cell, const std::shared_ptr &model_ return impl_->Build(); } -Status Model::Build(const void *, size_t, ModelType, const std::shared_ptr &, const Key &, - const std::vector &) { +Status Model::Build(const std::vector &, ModelType, const std::shared_ptr &, const Key &, + const std::string &, const std::vector &) { MS_LOG(ERROR) << "Unsupported Feature."; return kMCFailed; } -Status Model::Build(const std::vector &, ModelType, const std::shared_ptr &, const Key &, - const std::vector &) { +Status Model::Build(const std::vector &, ModelType, const std::shared_ptr &) { MS_LOG(ERROR) << "Unsupported Feature."; return kMCFailed; } diff --git a/mindspore/core/utils/crypto.cc b/mindspore/core/utils/crypto.cc index f47b409a0bf..6c91be29ce1 100644 --- a/mindspore/core/utils/crypto.cc +++ b/mindspore/core/utils/crypto.cc @@ -120,7 +120,7 @@ bool ParseMode(const std::string &mode, std::string *alg_mode, std::string *work } EVP_CIPHER_CTX *GetEvpCipherCtx(const std::string &work_mode, const Byte *key, int32_t key_len, const Byte *iv, - bool is_encrypt) { + int iv_len, bool is_encrypt) { constexpr int32_t key_length_16 = 16; constexpr int32_t key_length_24 = 24; constexpr int32_t key_length_32 = 32; @@ -163,8 +163,35 @@ EVP_CIPHER_CTX *GetEvpCipherCtx(const std::string &work_mode, const Byte *key, i int32_t ret = 0; auto ctx = EVP_CIPHER_CTX_new(); if (is_encrypt) { + ret = EVP_EncryptInit_ex(ctx, funcPtr(), NULL, NULL, NULL); + if (ret != 1) { + MS_LOG(ERROR) << "EVP_EncryptInit_ex failed"; + EVP_CIPHER_CTX_free(ctx); + return nullptr; + } + if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, iv_len, NULL) != 1) { + MS_LOG(ERROR) << "EVP_EncryptInit_ex failed"; + EVP_CIPHER_CTX_free(ctx); + return nullptr; + } ret = EVP_EncryptInit_ex(ctx, funcPtr(), NULL, key, iv); + if (ret != 1) { + MS_LOG(ERROR) << "EVP_EncryptInit_ex failed"; + EVP_CIPHER_CTX_free(ctx); + return nullptr; + } } else { + ret = EVP_DecryptInit_ex(ctx, funcPtr(), NULL, NULL, NULL); + if (ret != 1) { + MS_LOG(ERROR) << "EVP_DecryptInit_ex failed"; + EVP_CIPHER_CTX_free(ctx); + return nullptr; + } + if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, iv_len, NULL) != 1) { + MS_LOG(ERROR) << "EVP_DecryptInit_ex failed"; + EVP_CIPHER_CTX_free(ctx); + return nullptr; + } ret = EVP_DecryptInit_ex(ctx, funcPtr(), NULL, key, iv); } @@ -183,7 +210,7 @@ EVP_CIPHER_CTX *GetEvpCipherCtx(const std::string &work_mode, const Byte *key, i } bool BlockEncrypt(Byte *encrypt_data, size_t *encrypt_data_len, const std::vector &plain_data, const Byte *key, - int32_t key_len, const std::string &enc_mode) { + int32_t key_len, const std::string &enc_mode, unsigned char *tag) { size_t encrypt_data_buf_len = *encrypt_data_len; int32_t cipher_len = 0; int32_t iv_len = AES_BLOCK_SIZE; @@ -201,7 +228,7 @@ bool BlockEncrypt(Byte *encrypt_data, size_t *encrypt_data_len, const std::vecto return false; } - auto ctx = GetEvpCipherCtx(work_mode, key, key_len, iv.data(), true); + auto ctx = GetEvpCipherCtx(work_mode, key, key_len, iv.data(), static_cast(iv.size()), true); if (ctx == nullptr) { MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX."; return false; @@ -214,15 +241,19 @@ bool BlockEncrypt(Byte *encrypt_data, size_t *encrypt_data_len, const std::vecto MS_LOG(ERROR) << "EVP_EncryptUpdate failed"; return false; } - if (work_mode == "CBC") { - int32_t flen = 0; - ret_evp = EVP_EncryptFinal_ex(ctx, cipher_data_buf.data() + cipher_len, &flen); - if (ret_evp != 1) { - MS_LOG(ERROR) << "EVP_EncryptFinal_ex failed"; - return false; - } - cipher_len += flen; + int32_t flen = 0; + ret_evp = EVP_EncryptFinal_ex(ctx, cipher_data_buf.data() + cipher_len, &flen); + if (ret_evp != 1) { + MS_LOG(ERROR) << "EVP_EncryptFinal_ex failed"; + return false; } + cipher_len += flen; + + if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_GET_TAG, Byte16, tag) != 1) { + MS_LOG(ERROR) << "EVP_CIPHER_CTX_ctrl failed"; + return false; + } + EVP_CIPHER_CTX_free(ctx); size_t offset = 0; @@ -266,7 +297,7 @@ bool BlockEncrypt(Byte *encrypt_data, size_t *encrypt_data_len, const std::vecto } bool BlockDecrypt(Byte *plain_data, int32_t *plain_len, const Byte *encrypt_data, size_t encrypt_len, const Byte *key, - int32_t key_len, const std::string &dec_mode) { + int32_t key_len, const std::string &dec_mode, unsigned char *tag) { std::string alg_mode; std::string work_mode; if (!ParseMode(dec_mode, &alg_mode, &work_mode)) { @@ -277,7 +308,7 @@ bool BlockDecrypt(Byte *plain_data, int32_t *plain_len, const Byte *encrypt_data if (!ParseEncryptData(encrypt_data, encrypt_len, &iv, &cipher_data)) { return false; } - auto ctx = GetEvpCipherCtx(work_mode, key, key_len, iv.data(), false); + auto ctx = GetEvpCipherCtx(work_mode, key, key_len, iv.data(), iv.size(), false); if (ctx == nullptr) { MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX."; return false; @@ -288,15 +319,20 @@ bool BlockDecrypt(Byte *plain_data, int32_t *plain_len, const Byte *encrypt_data MS_LOG(ERROR) << "EVP_DecryptUpdate failed"; return false; } - if (work_mode == "CBC") { - int32_t mlen = 0; - ret = EVP_DecryptFinal_ex(ctx, plain_data + *plain_len, &mlen); - if (ret != 1) { - MS_LOG(ERROR) << "EVP_DecryptFinal_ex failed"; - return false; - } - *plain_len += mlen; + + if (!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_TAG, Byte16, tag)) { + MS_LOG(ERROR) << "EVP_CIPHER_CTX_ctrl failed"; + return false; } + + int32_t mlen = 0; + ret = EVP_DecryptFinal_ex(ctx, plain_data + *plain_len, &mlen); + if (ret != 1) { + MS_LOG(ERROR) << "EVP_DecryptFinal_ex failed"; + return false; + } + *plain_len += mlen; + EVP_CIPHER_CTX_free(ctx); return true; } @@ -319,7 +355,9 @@ std::unique_ptr Encrypt(size_t *encrypt_len, const Byte *plain_data, siz size_t block_enc_len = block_enc_buf.size(); size_t cur_block_size = std::min(MAX_BLOCK_SIZE, plain_len - offset); block_buf.assign(plain_data + offset, plain_data + offset + cur_block_size); - if (!BlockEncrypt(block_enc_buf.data(), &block_enc_len, block_buf, key, static_cast(key_len), enc_mode)) { + unsigned char tag[Byte16]; + if (!BlockEncrypt(block_enc_buf.data(), &block_enc_len, block_buf, key, static_cast(key_len), enc_mode, + tag)) { MS_LOG(ERROR) << "Failed to encrypt data, please check if enc_key or enc_mode is valid."; return nullptr; } @@ -332,6 +370,13 @@ std::unique_ptr Encrypt(size_t *encrypt_len, const Byte *plain_data, siz } *encrypt_len += sizeof(int32_t); + capacity = std::min(encrypt_buf_len - *encrypt_len, SECUREC_MEM_MAX_LEN); // avoid dest size over 2gb + ret = memcpy_s(encrypt_data.get() + *encrypt_len, capacity, tag, Byte16); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret; + } + *encrypt_len += Byte16; + capacity = std::min(encrypt_buf_len - *encrypt_len, SECUREC_MEM_MAX_LEN); ret = memcpy_s(encrypt_data.get() + *encrypt_len, capacity, block_enc_buf.data(), block_enc_len); if (ret != 0) { @@ -371,6 +416,10 @@ std::unique_ptr Decrypt(size_t *decrypt_len, const std::string &encrypt_ MS_LOG(ERROR) << "File \"" << encrypt_data_path << "\" is not an encrypted file and cannot be decrypted"; return nullptr; } + + unsigned char tag[Byte16]; + fid.read(reinterpret_cast(tag), Byte16); + fid.read(int_buf.data(), static_cast(sizeof(int32_t))); auto block_size = ByteToInt(reinterpret_cast(int_buf.data()), int_buf.size()); if (block_size < 0) { @@ -379,7 +428,7 @@ std::unique_ptr Decrypt(size_t *decrypt_len, const std::string &encrypt_ } fid.read(block_buf.data(), static_cast(block_size)); if (!(BlockDecrypt(decrypt_block_buf.data(), &decrypt_block_len, reinterpret_cast(block_buf.data()), - static_cast(block_size), key, static_cast(key_len), dec_mode))) { + static_cast(block_size), key, static_cast(key_len), dec_mode, tag))) { MS_LOG(ERROR) << "Failed to decrypt data, please check if dec_key or dec_mode is valid"; return nullptr; } @@ -409,6 +458,10 @@ std::unique_ptr Decrypt(size_t *decrypt_len, const Byte *model_data, siz size_t offset = 0; *decrypt_len = 0; while (offset < data_size) { + if (offset + sizeof(int32_t) > data_size) { + MS_LOG(ERROR) << "assign len is invalid."; + return nullptr; + } int_buf.assign(model_data + offset, model_data + offset + sizeof(int32_t)); offset += int_buf.size(); auto cipher_flag = ByteToInt(reinterpret_cast(int_buf.data()), int_buf.size()); @@ -416,27 +469,44 @@ std::unique_ptr Decrypt(size_t *decrypt_len, const Byte *model_data, siz MS_LOG(ERROR) << "model_data is not encrypted and therefore cannot be decrypted."; return nullptr; } - + unsigned char tag[Byte16]; + if (offset + Byte16 > data_size) { + MS_LOG(ERROR) << "buffer is invalid."; + return nullptr; + } + auto ret = memcpy_s(tag, Byte16, model_data + offset, Byte16); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "memcpy_s failed " << ret; + } + offset += Byte16; + if (offset + sizeof(int32_t) > data_size) { + MS_LOG(ERROR) << "assign len is invalid."; + return nullptr; + } int_buf.assign(model_data + offset, model_data + offset + sizeof(int32_t)); offset += int_buf.size(); auto block_size = ByteToInt(reinterpret_cast(int_buf.data()), int_buf.size()); - if (block_size < 0) { + if (block_size <= 0) { MS_LOG(ERROR) << "The block_size read from the cipher data must be not negative, but got " << block_size; return nullptr; } + if (offset + block_size > data_size) { + MS_LOG(ERROR) << "assign len is invalid."; + return nullptr; + } block_buf.assign(model_data + offset, model_data + offset + block_size); offset += block_buf.size(); if (!(BlockDecrypt(decrypt_block_buf.data(), &decrypt_block_len, reinterpret_cast(block_buf.data()), - block_buf.size(), key, static_cast(key_len), dec_mode))) { + block_buf.size(), key, static_cast(key_len), dec_mode, tag))) { MS_LOG(ERROR) << "Failed to decrypt data, please check if dec_key or dec_mode is valid"; return nullptr; } - size_t capacity = std::min(data_size - *decrypt_len, SECUREC_MEM_MAX_LEN); - auto ret = memcpy_s(decrypt_data.get() + *decrypt_len, capacity, decrypt_block_buf.data(), - static_cast(decrypt_block_len)); - if (ret != 0) { - MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret; + ret = memcpy_s(decrypt_data.get() + *decrypt_len, data_size, decrypt_block_buf.data(), + static_cast(decrypt_block_len)); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "memcpy_s failed " << ret; } + *decrypt_len += static_cast(decrypt_block_len); } return decrypt_data; diff --git a/mindspore/core/utils/crypto.h b/mindspore/core/utils/crypto.h index df400a67ef9..4dc43db04a0 100644 --- a/mindspore/core/utils/crypto.h +++ b/mindspore/core/utils/crypto.h @@ -26,6 +26,7 @@ namespace mindspore { constexpr size_t MAX_BLOCK_SIZE = 512 * 1024 * 1024; // Maximum ciphertext segment, units is Byte constexpr size_t RESERVED_BYTE_PER_BLOCK = 50; // Reserved byte per block to save addition info constexpr unsigned int MAGIC_NUM = 0x7F3A5ED8; // Magic number +constexpr size_t Byte16 = 16; MS_CORE_API std::unique_ptr Encrypt(size_t *encrypt_len, const Byte *plain_data, size_t plain_len, const Byte *key, size_t key_len, const std::string &enc_mode); diff --git a/mindspore/lite/CMakeLists.txt b/mindspore/lite/CMakeLists.txt index 8f037644490..c900ebcfdc9 100644 --- a/mindspore/lite/CMakeLists.txt +++ b/mindspore/lite/CMakeLists.txt @@ -34,7 +34,7 @@ option(MSLITE_ENABLE_V0 "support v0 schema" on) option(MSLITE_ENABLE_FP16 "Whether to compile Fp16 operator" off) option(MSLITE_ENABLE_INT8 "Whether to compile Int8 operator" on) option(MSLITE_ENABLE_ACL "enable ACL" off) -option(MSLITE_ENABLE_MODEL_ENCRYPTION "enable model encryption, only converter support" on) +option(MSLITE_ENABLE_MODEL_ENCRYPTION "enable model encryption, only converter support" off) option(MSLITE_ENABLE_SPARSE_COMPUTE "enable sparse kernel" off) option(MSLITE_ENABLE_RUNTIME_CONVERT "enable runtime convert" off) option(MSLITE_ENABLE_RUNTIME_GLOG "enable runtime glog" off) @@ -127,7 +127,11 @@ if(DEFINED ENV{MSLITE_MINDDATA_IMPLEMENT}) set(MSLITE_MINDDATA_IMPLEMENT $ENV{MSLITE_MINDDATA_IMPLEMENT}) endif() if(DEFINED ENV{MSLITE_ENABLE_MODEL_ENCRYPTION}) - set(MSLITE_ENABLE_MODEL_ENCRYPTION $ENV{MSLITE_ENABLE_MODEL_ENCRYPTION}) + if((${CMAKE_SYSTEM_NAME} MATCHES "Linux" AND PLATFORM_X86_64) OR (PLATFORM_ARM AND ANDROID_NDK_TOOLCHAIN_INCLUDED)) + set(MSLITE_ENABLE_MODEL_ENCRYPTION $ENV{MSLITE_ENABLE_MODEL_ENCRYPTION}) + else() + set(MSLITE_ENABLE_MODEL_ENCRYPTION OFF) + endif() endif() if(DEFINED ENV{MSLITE_ENABLE_RUNTIME_CONVERT}) @@ -227,7 +231,7 @@ if(PLATFORM_ARM64 OR PLATFORM_ARM32) endif() set(MSLITE_ENABLE_RUNTIME_GLOG off) set(MSLITE_ENABLE_RUNTIME_CONVERT off) -#set for cross - compiling toolchain + #set for cross - compiling toolchain set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH) set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE BOTH) set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE BOTH) @@ -540,13 +544,15 @@ if(MSLITE_ENABLE_CONVERTER) include_directories(${PYTHON_INCLUDE_DIRS}) include(${TOP_DIR}/cmake/external_libs/eigen.cmake) include(${TOP_DIR}/cmake/external_libs/protobuf.cmake) - if(MSLITE_ENABLE_MODEL_ENCRYPTION) - find_package(Patch) - include(${TOP_DIR}/cmake/external_libs/openssl.cmake) - endif() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter) endif() +if(MSLITE_ENABLE_MODEL_ENCRYPTION) + find_package(Patch) + include(${TOP_DIR}/cmake/external_libs/openssl.cmake) + add_compile_definitions(ENABLE_OPENSSL) +endif() + if(MSLITE_ENABLE_MINDRT) add_compile_definitions(ENABLE_MINDRT) endif() @@ -590,7 +596,7 @@ if(NOT PLATFORM_ARM) endif() if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "lite" OR MSLITE_MINDDATA_IMPLEMENT STREQUAL "full" - OR MSLITE_MINDDATA_IMPLEMENT STREQUAL "wrapper") + OR MSLITE_MINDDATA_IMPLEMENT STREQUAL "wrapper") add_compile_definitions(ENABLE_ANDROID) if(NOT PLATFORM_ARM32 AND NOT PLATFORM_ARM64) add_compile_definitions(ENABLE_MD_LITE_X86_64) @@ -605,7 +611,7 @@ endif() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/src/ops) if(ANDROID_NDK_TOOLCHAIN_INCLUDED) -add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter/micro/coder) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter/micro/coder) endif() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/src) diff --git a/mindspore/lite/build_lite.sh b/mindspore/lite/build_lite.sh index 7de4d82e2d4..ee6c7daa79f 100755 --- a/mindspore/lite/build_lite.sh +++ b/mindspore/lite/build_lite.sh @@ -206,6 +206,7 @@ build_lite() { LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_ENABLE_FP16=off -DMSLITE_ENABLE_TRAIN=off -DMSLITE_GPU_BACKEND=off" else checkndk + export PATH=${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/bin:${ANDROID_NDK}/toolchains/arm-linux-androideabi-4.9/prebuilt/linux-x86_64/bin:${PATH} CMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_MINDDATA_IMPLEMENT=lite_cv" LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_ENABLE_FP16=on" @@ -237,6 +238,7 @@ build_lite() { ARM64_COMPILE_CONVERTER=ON else checkndk + export PATH=${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/bin:${ANDROID_NDK}/toolchains/arm-linux-androideabi-4.9/prebuilt/linux-x86_64/bin:${PATH} CMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DANDROID_NATIVE_API_LEVEL=19 -DANDROID_NDK=${ANDROID_NDK} -DANDROID_ABI=arm64-v8a -DANDROID_TOOLCHAIN_NAME=aarch64-linux-android-clang -DANDROID_STL=${MSLITE_ANDROID_STL}" LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_MINDDATA_IMPLEMENT=lite_cv" diff --git a/mindspore/lite/java/src/main/java/com/mindspore/Model.java b/mindspore/lite/java/src/main/java/com/mindspore/Model.java index 2289751a478..4804cf6638c 100644 --- a/mindspore/lite/java/src/main/java/com/mindspore/Model.java +++ b/mindspore/lite/java/src/main/java/com/mindspore/Model.java @@ -58,19 +58,19 @@ public class Model { /** * Build model. * - * @param buffer model buffer. - * @param modelType model type. - * @param context model build context. - * @param dec_key define the key used to decrypt the ciphertext model. The key length is 16, 24, or 32. - * @param dec_mode define the decryption mode. Options: AES-GCM, AES-CBC. + * @param buffer model buffer. + * @param modelType model type. + * @param context model build context. + * @param dec_key define the key used to decrypt the ciphertext model. The key length is 16. + * @param dec_mode define the decryption mode. Options: AES-GCM. + * @param cropto_lib_path define the openssl library path. * @return model build status. */ - public boolean build(final MappedByteBuffer buffer, int modelType, MSContext context, char[] dec_key, - String dec_mode) { + public boolean build(final MappedByteBuffer buffer, int modelType, MSContext context, char[] dec_key, String dec_mode, String cropto_lib_path) { if (context == null || buffer == null || dec_key == null || dec_mode == null) { return false; } - modelPtr = this.buildByBuffer(buffer, modelType, context.getMSContextPtr(), dec_key, dec_mode); + modelPtr = this.buildByBuffer(buffer, modelType, context.getMSContextPtr(), dec_key, dec_mode, cropto_lib_path); return modelPtr != 0; } @@ -86,7 +86,7 @@ public class Model { if (context == null || buffer == null) { return false; } - modelPtr = this.buildByBuffer(buffer, modelType, context.getMSContextPtr(), null, ""); + modelPtr = this.buildByBuffer(buffer, modelType, context.getMSContextPtr(), null, "", ""); return modelPtr != 0; } @@ -94,18 +94,19 @@ public class Model { /** * Build model. * - * @param modelPath model path. - * @param modelType model type. - * @param context model build context. - * @param dec_key define the key used to decrypt the ciphertext model. The key length is 16, 24, or 32. - * @param dec_mode define the decryption mode. Options: AES-GCM, AES-CBC. + * @param modelPath model path. + * @param modelType model type. + * @param context model build context. + * @param dec_key define the key used to decrypt the ciphertext model. The key length is 16. + * @param dec_mode define the decryption mode. Options: AES-GCM. + * @param cropto_lib_path define the openssl library path. * @return model build status. */ - public boolean build(String modelPath, int modelType, MSContext context, char[] dec_key, String dec_mode) { + public boolean build(String modelPath, int modelType, MSContext context, char[] dec_key, String dec_mode, String cropto_lib_path) { if (context == null || modelPath == null || dec_key == null || dec_mode == null) { return false; } - modelPtr = this.buildByPath(modelPath, modelType, context.getMSContextPtr(), dec_key, dec_mode); + modelPtr = this.buildByPath(modelPath, modelType, context.getMSContextPtr(), dec_key, dec_mode, cropto_lib_path); return modelPtr != 0; } @@ -121,7 +122,7 @@ public class Model { if (context == null || modelPath == null) { return false; } - modelPtr = this.buildByPath(modelPath, modelType, context.getMSContextPtr(), null, ""); + modelPtr = this.buildByPath(modelPath, modelType, context.getMSContextPtr(), null, "", ""); return modelPtr != 0; } @@ -256,8 +257,7 @@ public class Model { * @param outputTensorNames tensor name used for export inference graph. * @return Whether the export is successful. */ - public boolean export(String fileName, int quantizationType, boolean isOnlyExportInfer, - List outputTensorNames) { + public boolean export(String fileName, int quantizationType, boolean isOnlyExportInfer, List outputTensorNames) { if (fileName == null) { return false; } @@ -355,10 +355,11 @@ public class Model { private native long buildByGraph(long graphPtr, long contextPtr, long cfgPtr); - private native long buildByPath(String modelPath, int modelType, long contextPtr, char[] dec_key, String dec_mod); + private native long buildByPath(String modelPath, int modelType, long contextPtr, + char[] dec_key, String dec_mod, String cropto_lib_path); - private native long buildByBuffer(MappedByteBuffer buffer, int modelType, long contextPtr, char[] dec_key, - String dec_mod); + private native long buildByBuffer(MappedByteBuffer buffer, int modelType, long contextPtr, + char[] dec_key, String dec_mod, String cropto_lib_path); private native List getInputs(long modelPtr); @@ -380,8 +381,7 @@ public class Model { private native boolean resize(long modelPtr, long[] inputs, int[][] dims); - private native boolean export(long modelPtr, String fileName, int quantizationType, boolean isOnlyExportInfer, - String[] outputTensorNames); + private native boolean export(long modelPtr, String fileName, int quantizationType, boolean isOnlyExportInfer, String[] outputTensorNames); private native List getFeatureMaps(long modelPtr); @@ -389,6 +389,5 @@ public class Model { private native boolean setLearningRate(long modelPtr, float learning_rate); - private native boolean setupVirtualBatch(long modelPtr, int virtualBatchMultiplier, float learningRate, - float momentum); + private native boolean setupVirtualBatch(long modelPtr, int virtualBatchMultiplier, float learningRate, float momentum); } diff --git a/mindspore/lite/java/src/main/native/model.cpp b/mindspore/lite/java/src/main/native/model.cpp index 77a3a77801b..8c7339822fd 100644 --- a/mindspore/lite/java/src/main/native/model.cpp +++ b/mindspore/lite/java/src/main/native/model.cpp @@ -68,7 +68,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByGraph(JNIEnv extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByBuffer(JNIEnv *env, jobject thiz, jobject model_buffer, jint model_type, jlong context_ptr, jcharArray key_str, - jstring dec_mod) { + jstring dec_mod, jstring cropto_lib_path) { if (model_buffer == nullptr) { MS_LOGE("Buffer from java is nullptr"); return reinterpret_cast(nullptr); @@ -116,7 +116,8 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByBuffer(JNIEnv } env->ReleaseCharArrayElements(key_str, key_array, JNI_ABORT); mindspore::Key dec_key{dec_key_data, key_len}; - status = model->Build(model_buf, buffer_len, c_model_type, context, dec_key, c_dec_mod); + auto c_cropto_lib_path = env->GetStringUTFChars(cropto_lib_path, JNI_FALSE); + status = model->Build(model_buf, buffer_len, c_model_type, context, dec_key, c_dec_mod, c_cropto_lib_path); } else { status = model->Build(model_buf, buffer_len, c_model_type, context); } @@ -130,7 +131,8 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByBuffer(JNIEnv extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByPath(JNIEnv *env, jobject thiz, jstring model_path, jint model_type, jlong context_ptr, - jcharArray key_str, jstring dec_mod) { + jcharArray key_str, jstring dec_mod, + jstring cropto_lib_path) { auto c_model_path = env->GetStringUTFChars(model_path, JNI_FALSE); mindspore::ModelType c_model_type; if (model_type >= static_cast(mindspore::kMindIR) && model_type <= static_cast(mindspore::kMindIR_Lite)) { @@ -172,7 +174,8 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByPath(JNIEnv * } env->ReleaseCharArrayElements(key_str, key_array, JNI_ABORT); mindspore::Key dec_key{dec_key_data, key_len}; - status = model->Build(c_model_path, c_model_type, context, dec_key, c_dec_mod); + auto c_cropto_lib_path = env->GetStringUTFChars(cropto_lib_path, JNI_FALSE); + status = model->Build(c_model_path, c_model_type, context, dec_key, c_dec_mod, c_cropto_lib_path); } else { status = model->Build(c_model_path, c_model_type, context); } diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt index fbdd3022665..5e1d69162c8 100644 --- a/mindspore/lite/src/CMakeLists.txt +++ b/mindspore/lite/src/CMakeLists.txt @@ -131,6 +131,14 @@ set(LITE_SRC ${CMAKE_CURRENT_SOURCE_DIR}/cpu_info.cc ) +if(MSLITE_ENABLE_MODEL_ENCRYPTION) + set(LITE_SRC + ${LITE_SRC} + ${CMAKE_CURRENT_SOURCE_DIR}/common/decrypt.cc + ${CMAKE_CURRENT_SOURCE_DIR}/common/dynamic_library_loader.cc + ) +endif() + if(MSLITE_ENABLE_SERVER_INFERENCE) set(LITE_SRC ${LITE_SRC} @@ -272,8 +280,7 @@ set(TRAIN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/train/classification_train_accuracy_monitor.cc ${CMAKE_CURRENT_SOURCE_DIR}/train/train_export.cc ${CMAKE_CURRENT_SOURCE_DIR}/train/opt_allocator.cc - ${TOOLS_DIR}/common/storage.cc - ${TOOLS_DIR}/common/meta_graph_serializer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/common/storage.cc ${TOOLS_DIR}/converter/optimizer.cc ${TOOLS_DIR}/converter/legacy_optimizer/fusion/fusion_pass.cc ${TOOLS_DIR}/converter/legacy_optimizer/fusion/fusion_pattern.cc diff --git a/mindspore/lite/src/common/decrypt.cc b/mindspore/lite/src/common/decrypt.cc new file mode 100644 index 00000000000..466efe70712 --- /dev/null +++ b/mindspore/lite/src/common/decrypt.cc @@ -0,0 +1,323 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/common/decrypt.h" +#ifdef ENABLE_OPENSSL +#include +#include +#include +#include +#include +#include +#include +#include "src/common/dynamic_library_loader.h" +#include "src/common/file_utils.h" +#endif +#include "src/common/log_adapter.h" +#include "src/common/log_util.h" + +#ifndef SECUREC_MEM_MAX_LEN +#define SECUREC_MEM_MAX_LEN 0x7fffffffUL +#endif + +namespace mindspore::lite { +#ifndef ENABLE_OPENSSL +std::unique_ptr Decrypt(const std::string &lib_path, size_t *, const Byte *, const size_t, const Byte *, + const size_t, const std::string &) { + MS_LOG(ERROR) << "The feature is only supported on the Linux platform " + "when the OPENSSL compilation option is enabled."; + return nullptr; +} +#else +namespace { +constexpr size_t MAX_BLOCK_SIZE = 512 * 1024 * 1024; // Maximum ciphertext segment, units is Byte +constexpr size_t Byte16 = 16; // Byte16 +constexpr unsigned int MAGIC_NUM = 0x7F3A5ED8; // Magic number +DynamicLibraryLoader loader; +} // namespace +int32_t ByteToInt(const Byte *byteArray, size_t length) { + if (byteArray == nullptr) { + MS_LOG(ERROR) << "There is a null pointer in the input parameter."; + return -1; + } + if (length < sizeof(int32_t)) { + MS_LOG(ERROR) << "Length of byteArray is " << length << ", less than sizeof(int32_t): 4."; + return -1; + } + return *(reinterpret_cast(byteArray)); +} + +bool ParseEncryptData(const Byte *encrypt_data, size_t encrypt_len, std::vector *iv, + std::vector *cipher_data) { + if (encrypt_data == nullptr || iv == nullptr || cipher_data == nullptr) { + MS_LOG(ERROR) << "There is a null pointer in the input parameter."; + return false; + } + // encrypt_data is organized in order to iv_len, iv, cipher_len, cipher_data + std::vector int_buf(sizeof(int32_t)); + if (sizeof(int32_t) > encrypt_len) { + MS_LOG(ERROR) << "assign len is invalid."; + return false; + } + int_buf.assign(encrypt_data, encrypt_data + sizeof(int32_t)); + auto iv_len = ByteToInt(int_buf.data(), int_buf.size()); + if (iv_len <= 0 || iv_len + sizeof(int32_t) + sizeof(int32_t) > encrypt_len) { + MS_LOG(ERROR) << "assign len is invalid."; + return false; + } + int_buf.assign(encrypt_data + iv_len + sizeof(int32_t), encrypt_data + iv_len + sizeof(int32_t) + sizeof(int32_t)); + auto cipher_len = ByteToInt(int_buf.data(), int_buf.size()); + if (((static_cast(iv_len) + sizeof(int32_t) + static_cast(cipher_len) + sizeof(int32_t)) != + encrypt_len)) { + MS_LOG(ERROR) << "Failed to parse encrypt data."; + return false; + } + + (*iv).assign(encrypt_data + sizeof(int32_t), encrypt_data + sizeof(int32_t) + iv_len); + if (cipher_len <= 0 || sizeof(int32_t) + iv_len + sizeof(int32_t) + cipher_len > encrypt_len) { + MS_LOG(ERROR) << "assign len is invalid."; + return false; + } + (*cipher_data) + .assign(encrypt_data + sizeof(int32_t) + iv_len + sizeof(int32_t), + encrypt_data + sizeof(int32_t) + iv_len + sizeof(int32_t) + cipher_len); + return true; +} + +bool ParseMode(const std::string &mode, std::string *alg_mode, std::string *work_mode) { + if (alg_mode == nullptr || work_mode == nullptr) { + MS_LOG(ERROR) << "There is a null pointer in the input parameter."; + return false; + } + std::smatch results; + std::regex re("([A-Z]{3})-([A-Z]{3})"); + if (!(std::regex_match(mode.c_str(), re) && std::regex_search(mode, results, re))) { + MS_LOG(ERROR) << "Mode " << mode << " is invalid."; + return false; + } + const size_t index_1 = 1; + const size_t index_2 = 2; + *alg_mode = results[index_1]; + *work_mode = results[index_2]; + return true; +} + +EVP_CIPHER_CTX *GetEvpCipherCtx(const std::string &work_mode, const Byte *key, int32_t key_len, const Byte *iv, + int iv_len) { + constexpr int32_t key_length_16 = 16; + constexpr int32_t key_length_24 = 24; + constexpr int32_t key_length_32 = 32; + const EVP_CIPHER *(*funcPtr)() = nullptr; + if (work_mode == "GCM") { + switch (key_len) { + case key_length_16: + funcPtr = (const EVP_CIPHER *(*)())loader.GetFunc("EVP_aes_128_gcm"); + break; + case key_length_24: + funcPtr = (const EVP_CIPHER *(*)())loader.GetFunc("EVP_aes_192_gcm"); + break; + case key_length_32: + funcPtr = (const EVP_CIPHER *(*)())loader.GetFunc("EVP_aes_256_gcm"); + break; + default: + MS_LOG(ERROR) << "The key length must be 16, 24 or 32, but got key length is " << key_len << "."; + return nullptr; + } + } else { + MS_LOG(ERROR) << "Work mode " << work_mode << " is invalid."; + return nullptr; + } + + int32_t ret = 0; + EVP_CIPHER_CTX *(*EVP_CIPHER_CTX_new)() = (EVP_CIPHER_CTX * (*)()) loader.GetFunc("EVP_CIPHER_CTX_new"); + EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new(); + if (ctx == nullptr) { + MS_LOG(ERROR) << "EVP_CIPHER_CTX_new failed"; + return nullptr; + } + + int (*EVP_DecryptInit_ex)(EVP_CIPHER_CTX *, const EVP_CIPHER *, ENGINE *, const unsigned char *, + const unsigned char *) = + (int (*)(EVP_CIPHER_CTX *, const EVP_CIPHER *, ENGINE *, const unsigned char *, + const unsigned char *))loader.GetFunc("EVP_DecryptInit_ex"); + ret = EVP_DecryptInit_ex(ctx, funcPtr(), NULL, NULL, NULL); + int (*EVP_CIPHER_CTX_ctrl)(EVP_CIPHER_CTX *, int, int, void *) = + (int (*)(EVP_CIPHER_CTX * ctx, int type, int arg, void *ptr)) loader.GetFunc("EVP_CIPHER_CTX_ctrl"); + if (ret != 1) { + MS_LOG(ERROR) << "EVP_DecryptInit_ex failed"; + void (*EVP_CIPHER_CTX_free)(EVP_CIPHER_CTX *) = (void (*)(EVP_CIPHER_CTX *))loader.GetFunc("EVP_CIPHER_CTX_free"); + EVP_CIPHER_CTX_free(ctx); + return nullptr; + } + if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, iv_len, NULL) != 1) { + MS_LOG(ERROR) << "EVP_DecryptInit_ex failed"; + void (*EVP_CIPHER_CTX_free)(EVP_CIPHER_CTX *) = (void (*)(EVP_CIPHER_CTX *))loader.GetFunc("EVP_CIPHER_CTX_free"); + EVP_CIPHER_CTX_free(ctx); + return nullptr; + } + ret = EVP_DecryptInit_ex(ctx, funcPtr(), NULL, key, iv); + if (ret != 1) { + MS_LOG(ERROR) << "EVP_DecryptInit_ex failed"; + void (*EVP_CIPHER_CTX_free)(EVP_CIPHER_CTX *) = (void (*)(EVP_CIPHER_CTX *))loader.GetFunc("EVP_CIPHER_CTX_free"); + EVP_CIPHER_CTX_free(ctx); + return nullptr; + } + return ctx; +} + +bool BlockDecrypt(Byte *plain_data, int32_t *plain_len, const Byte *encrypt_data, size_t encrypt_len, const Byte *key, + int32_t key_len, const std::string &dec_mode, unsigned char *tag) { + if (plain_data == nullptr || plain_len == nullptr || encrypt_data == nullptr || key == nullptr) { + MS_LOG(ERROR) << "There is a null pointer in the input parameter."; + return false; + } + std::string alg_mode; + std::string work_mode; + if (!ParseMode(dec_mode, &alg_mode, &work_mode)) { + return false; + } + std::vector iv; + std::vector cipher_data; + if (!ParseEncryptData(encrypt_data, encrypt_len, &iv, &cipher_data)) { + return false; + } + auto ctx = GetEvpCipherCtx(work_mode, key, key_len, iv.data(), static_cast(iv.size())); + if (ctx == nullptr) { + MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX."; + return false; + } + int (*EVP_DecryptUpdate)(EVP_CIPHER_CTX *, unsigned char *, int *, const unsigned char *, int) = + (int (*)(EVP_CIPHER_CTX *, unsigned char *, int *, const unsigned char *, int))loader.GetFunc("EVP_DecryptUpdate"); + auto ret = + EVP_DecryptUpdate(ctx, plain_data, plain_len, cipher_data.data(), static_cast(cipher_data.size())); + if (ret != 1) { + MS_LOG(ERROR) << "EVP_DecryptUpdate failed"; + return false; + } + + int (*EVP_CIPHER_CTX_ctrl)(EVP_CIPHER_CTX *, int, int, void *) = + (int (*)(EVP_CIPHER_CTX *, int, int, void *))loader.GetFunc("EVP_CIPHER_CTX_ctrl"); + if (!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_TAG, Byte16, tag)) { + MS_LOG(ERROR) << "EVP_CIPHER_CTX_ctrl failed"; + return false; + } + + int32_t mlen = 0; + int (*EVP_DecryptFinal_ex)(EVP_CIPHER_CTX *, unsigned char *, int *) = + (int (*)(EVP_CIPHER_CTX *, unsigned char *, int *))loader.GetFunc("EVP_DecryptFinal_ex"); + ret = EVP_DecryptFinal_ex(ctx, plain_data + *plain_len, &mlen); + if (ret != 1) { + MS_LOG(ERROR) << "EVP_DecryptFinal_ex failed"; + return false; + } + *plain_len += mlen; + + void (*EVP_CIPHER_CTX_free)(EVP_CIPHER_CTX *) = (void (*)(EVP_CIPHER_CTX *))loader.GetFunc("EVP_CIPHER_CTX_free"); + EVP_CIPHER_CTX_free(ctx); + iv.assign(iv.size(), 0); + return true; +} + +std::unique_ptr Decrypt(const std::string &lib_path, size_t *decrypt_len, const Byte *model_data, + const size_t data_size, const Byte *key, const size_t key_len, + const std::string &dec_mode) { + if (model_data == nullptr) { + MS_LOG(ERROR) << "model_data is nullptr."; + return nullptr; + } + if (key == nullptr) { + MS_LOG(ERROR) << "key is nullptr."; + return nullptr; + } + if (decrypt_len == nullptr) { + MS_LOG(ERROR) << "decrypt_len is nullptr."; + return nullptr; + } + auto ret = loader.Open(lib_path); + if (ret != RET_OK) { + MS_LOG(ERROR) << "loader open failed."; + return nullptr; + } + std::vector block_buf; + std::vector int_buf(sizeof(int32_t)); + std::vector decrypt_block_buf(MAX_BLOCK_SIZE); + auto decrypt_data = std::make_unique(data_size); + int32_t decrypt_block_len; + + size_t offset = 0; + *decrypt_len = 0; + if (dec_mode != "AES-GCM") { + MS_LOG(ERROR) << "dec_mode only support AES-GCM."; + return nullptr; + } + if (key_len != Byte16) { + MS_LOG(ERROR) << "key_len only support 16."; + return nullptr; + } + while (offset < data_size) { + if (offset + sizeof(int32_t) > data_size) { + MS_LOG(ERROR) << "assign len is invalid."; + return nullptr; + } + int_buf.assign(model_data + offset, model_data + offset + sizeof(int32_t)); + offset += int_buf.size(); + auto cipher_flag = ByteToInt(reinterpret_cast(int_buf.data()), int_buf.size()); + if (static_cast(cipher_flag) != MAGIC_NUM) { + MS_LOG(ERROR) << "model_data is not encrypted and therefore cannot be decrypted."; + return nullptr; + } + unsigned char tag[Byte16]; + if (offset + Byte16 > data_size) { + MS_LOG(ERROR) << "buffer is invalid."; + return nullptr; + } + memcpy(tag, model_data + offset, Byte16); + offset += Byte16; + if (offset + sizeof(int32_t) > data_size) { + MS_LOG(ERROR) << "assign len is invalid."; + return nullptr; + } + int_buf.assign(model_data + offset, model_data + offset + sizeof(int32_t)); + offset += int_buf.size(); + auto block_size = ByteToInt(reinterpret_cast(int_buf.data()), int_buf.size()); + if (block_size <= 0) { + MS_LOG(ERROR) << "The block_size read from the cipher data must be not negative, but got " << block_size; + return nullptr; + } + if (offset + block_size > data_size) { + MS_LOG(ERROR) << "assign len is invalid."; + return nullptr; + } + block_buf.assign(model_data + offset, model_data + offset + block_size); + offset += block_buf.size(); + if (!(BlockDecrypt(decrypt_block_buf.data(), &decrypt_block_len, reinterpret_cast(block_buf.data()), + block_buf.size(), key, static_cast(key_len), dec_mode, tag))) { + MS_LOG(ERROR) << "Failed to decrypt data, please check if dec_key or dec_mode is valid"; + return nullptr; + } + memcpy(decrypt_data.get() + *decrypt_len, decrypt_block_buf.data(), static_cast(decrypt_block_len)); + + *decrypt_len += static_cast(decrypt_block_len); + } + ret = loader.Close(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "loader close failed."; + return nullptr; + } + return decrypt_data; +} +#endif +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/decrypt.h b/mindspore/lite/src/common/decrypt.h new file mode 100644 index 00000000000..e9cc2c49c63 --- /dev/null +++ b/mindspore/lite/src/common/decrypt.h @@ -0,0 +1,29 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_UTILS_DECRYPT_H_ +#define MINDSPORE_CORE_UTILS_DECRYPT_H_ + +#include +#include + +typedef unsigned char Byte; +namespace mindspore::lite { +std::unique_ptr Decrypt(const std::string &lib_path, size_t *decrypt_len, const Byte *model_data, + const size_t data_size, const Byte *key, const size_t key_len, + const std::string &dec_mode); +} // namespace mindspore::lite +#endif diff --git a/mindspore/lite/src/common/dynamic_library_loader.cc b/mindspore/lite/src/common/dynamic_library_loader.cc index c380b08dce1..954b6b07f3b 100644 --- a/mindspore/lite/src/common/dynamic_library_loader.cc +++ b/mindspore/lite/src/common/dynamic_library_loader.cc @@ -30,10 +30,13 @@ namespace mindspore { namespace lite { int DynamicLibraryLoader::Open(const std::string &lib_path) { if (handler_ != nullptr) { - return RET_ERROR; + return RET_OK; } std::string real_path = RealPath(lib_path.c_str()); - + if (real_path.empty()) { + MS_LOG(ERROR) << "real_path is invalid."; + return RET_ERROR; + } #ifndef _WIN32 #ifndef ENABLE_ARM handler_ = dlopen(real_path.c_str(), RTLD_LAZY | RTLD_DEEPBIND); diff --git a/mindspore/lite/src/common/storage.cc b/mindspore/lite/src/common/storage.cc new file mode 100644 index 00000000000..b39051fab84 --- /dev/null +++ b/mindspore/lite/src/common/storage.cc @@ -0,0 +1,99 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/common/storage.h" +#include +#ifndef _MSC_VER +#include +#endif +#include "flatbuffers/flatbuffers.h" +#include "src/common/log_adapter.h" +#include "src/common/file_utils.h" + +namespace mindspore { +namespace lite { +namespace { +constexpr size_t kMaxNum1024 = 1024; +} +int Storage::Save(const schema::MetaGraphT &graph, const std::string &outputPath) { + flatbuffers::FlatBufferBuilder builder(kMaxNum1024); + auto offset = schema::MetaGraph::Pack(builder, &graph); + builder.Finish(offset); + schema::FinishMetaGraphBuffer(builder, offset); + int size = builder.GetSize(); + auto content = builder.GetBufferPointer(); + if (content == nullptr) { + MS_LOG(ERROR) << "GetBufferPointer nullptr"; + return RET_ERROR; + } + std::string filename = outputPath; + if (filename.length() == 0) { + MS_LOG(ERROR) << "Invalid output path."; + return RET_ERROR; + } + if (filename.substr(filename.find_last_of(".") + 1) != "ms") { + filename = filename + ".ms"; + } +#ifndef _MSC_VER + if (access(filename.c_str(), F_OK) == 0) { + chmod(filename.c_str(), S_IWUSR); + } +#endif + std::ofstream output(filename, std::ofstream::binary); + if (!output.is_open()) { + MS_LOG(ERROR) << "Can not open output file: " << filename; + return RET_ERROR; + } + output.write((const char *)content, size); + if (output.bad()) { + output.close(); + MS_LOG(ERROR) << "Write output file : " << filename << " failed"; + return RET_ERROR; + } + output.close(); +#ifndef _MSC_VER + chmod(filename.c_str(), S_IRUSR); +#endif + return RET_OK; +} + +schema::MetaGraphT *Storage::Load(const std::string &inputPath) { + size_t size = 0; + std::string filename = inputPath; + if (filename.length() == 0) { + MS_LOG(ERROR) << "Invalid input path."; + return nullptr; + } + if (filename.substr(filename.find_last_of(".") + 1) != "ms") { + filename = filename + ".ms"; + } + auto buf = ReadFile(filename.c_str(), &size); + if (buf == nullptr) { + MS_LOG(ERROR) << "the file buffer is nullptr"; + return nullptr; + } + + flatbuffers::Verifier verify((const uint8_t *)buf, size); + if (!schema::VerifyMetaGraphBuffer(verify)) { + MS_LOG(ERROR) << "the buffer is invalid and fail to create meta graph"; + return nullptr; + } + + auto graphDefT = schema::UnPackMetaGraph(buf); + return graphDefT.release(); +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/common/storage.h b/mindspore/lite/src/common/storage.h new file mode 100644 index 00000000000..4d30d982754 --- /dev/null +++ b/mindspore/lite/src/common/storage.h @@ -0,0 +1,37 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_COMMON_STORAGE_H +#define MINDSPORE_LITE_SRC_COMMON_STORAGE_H + +#include +#include +#include "include/errorcode.h" +#include "flatbuffers/flatbuffers.h" +#include "schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { +class Storage { + public: + static int Save(const schema::MetaGraphT &graph, const std::string &outputPath); + + static schema::MetaGraphT *Load(const std::string &inputPath); +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_COMMON_STORAGE_H diff --git a/mindspore/lite/src/cxx_api/model/model.cc b/mindspore/lite/src/cxx_api/model/model.cc index db1dd6b8fdf..c4a599f0ff7 100644 --- a/mindspore/lite/src/cxx_api/model/model.cc +++ b/mindspore/lite/src/cxx_api/model/model.cc @@ -30,16 +30,74 @@ #include "src/cxx_api/callback/callback_adapter.h" #include "src/cxx_api/callback/callback_impl.h" #include "src/cxx_api/model/model_impl.h" +#ifdef ENABLE_OPENSSL +#include "src/common/decrypt.h" +#include "src/common/file_utils.h" +#endif namespace mindspore { std::mutex g_impl_init_lock; +#ifdef ENABLE_OPENSSL +Status DecryptModel(const std::string &cropto_lib_path, const void *model_buf, size_t model_size, const Key &dec_key, + const std::string &dec_mode, std::unique_ptr *decrypt_buffer, size_t *decrypt_len) { + if (model_buf == nullptr) { + MS_LOG(ERROR) << "model_buf is nullptr."; + return kLiteError; + } + *decrypt_len = 0; + *decrypt_buffer = lite::Decrypt(cropto_lib_path, decrypt_len, reinterpret_cast(model_buf), model_size, + dec_key.key, dec_key.len, dec_mode); + if (*decrypt_buffer == nullptr || *decrypt_len == 0) { + MS_LOG(ERROR) << "Decrypt buffer failed"; + return kLiteError; + } + return kSuccess; +} +#endif Status Model::Build(const void *model_data, size_t data_size, ModelType model_type, - const std::shared_ptr &model_context, const Key &dec_key, - const std::vector &dec_mode) { + const std::shared_ptr &model_context, const Key &dec_key, const std::string &dec_mode, + const std::string &cropto_lib_path) { +#ifdef ENABLE_OPENSSL if (impl_ == nullptr) { std::unique_lock impl_lock(g_impl_init_lock); - impl_ = std::shared_ptr(new (std::nothrow) ModelImpl()); + impl_ = std::make_shared(); + if (impl_ == nullptr) { + MS_LOG(ERROR) << "Model implement is null."; + return kLiteFileError; + } + } + if (dec_key.len > 0) { + std::unique_ptr decrypt_buffer; + size_t decrypt_len = 0; + Status ret = DecryptModel(cropto_lib_path, model_data, data_size, dec_key, dec_mode, &decrypt_buffer, &decrypt_len); + if (ret != kSuccess) { + MS_LOG(ERROR) << "Decrypt model failed."; + return ret; + } + ret = impl_->Build(decrypt_buffer.get(), decrypt_len, model_type, model_context); + if (ret != kSuccess) { + MS_LOG(ERROR) << "Build model failed."; + return ret; + } + } else { + Status ret = impl_->Build(model_data, data_size, model_type, model_context); + if (ret != kSuccess) { + return ret; + } + } + return kSuccess; +#else + MS_LOG(ERROR) << "The lib is not support Decrypt Model."; + return kLiteError; +#endif +} + +Status Model::Build(const void *model_data, size_t data_size, ModelType model_type, + const std::shared_ptr &model_context) { + if (impl_ == nullptr) { + std::unique_lock impl_lock(g_impl_init_lock); + impl_ = std::make_shared(); if (impl_ == nullptr) { MS_LOG(ERROR) << "Model implement is null."; return kLiteFileError; @@ -54,11 +112,59 @@ Status Model::Build(const void *model_data, size_t data_size, ModelType model_ty } Status Model::Build(const std::vector &model_path, ModelType model_type, - const std::shared_ptr &model_context, const Key &dec_key, - const std::vector &dec_mode) { + const std::shared_ptr &model_context, const Key &dec_key, const std::string &dec_mode, + const std::vector &cropto_lib_path) { +#ifdef ENABLE_OPENSSL if (impl_ == nullptr) { std::unique_lock impl_lock(g_impl_init_lock); - impl_ = std::shared_ptr(new (std::nothrow) ModelImpl()); + impl_ = std::make_shared(); + if (impl_ == nullptr) { + MS_LOG(ERROR) << "Model implement is null."; + return kLiteFileError; + } + } + if (dec_key.len > 0) { + size_t model_size; + auto model_buf = lite::ReadFile(model_path.data(), &model_size); + if (model_buf == nullptr) { + MS_LOG(ERROR) << "Read model file failed"; + return kLiteError; + } + std::unique_ptr decrypt_buffer; + size_t decrypt_len = 0; + Status ret = DecryptModel(CharToString(cropto_lib_path), model_buf, model_size, dec_key, dec_mode, &decrypt_buffer, + &decrypt_len); + if (ret != kSuccess) { + MS_LOG(ERROR) << "Decrypt model failed."; + delete[] model_buf; + return ret; + } + ret = impl_->Build(decrypt_buffer.get(), decrypt_len, model_type, model_context); + if (ret != kSuccess) { + MS_LOG(ERROR) << "Build model failed."; + delete[] model_buf; + return ret; + } + delete[] model_buf; + } else { + Status ret = impl_->Build(CharToString(model_path), model_type, model_context); + if (ret != kSuccess) { + MS_LOG(ERROR) << "Build model failed."; + return ret; + } + } + return kSuccess; +#else + MS_LOG(ERROR) << "The lib is not support Decrypt Model."; + return kLiteError; +#endif +} + +Status Model::Build(const std::vector &model_path, ModelType model_type, + const std::shared_ptr &model_context) { + if (impl_ == nullptr) { + std::unique_lock impl_lock(g_impl_init_lock); + impl_ = std::make_shared(); if (impl_ == nullptr) { MS_LOG(ERROR) << "Model implement is null."; return kLiteFileError; @@ -77,7 +183,7 @@ Status Model::Build(GraphCell graph, const std::shared_ptr &model_conte std::stringstream err_msg; if (impl_ == nullptr) { std::unique_lock impl_lock(g_impl_init_lock); - impl_ = std::shared_ptr(new (std::nothrow) ModelImpl()); + impl_ = std::make_shared(); if (impl_ == nullptr) { MS_LOG(ERROR) << "Model implement is null."; return kLiteFileError; @@ -258,7 +364,7 @@ Status Model::LoadConfig(const std::vector &config_path) { return Status(kLiteFileError, "Illegal operation."); } - impl_ = std::shared_ptr(new (std::nothrow) ModelImpl()); + impl_ = std::make_shared(); if (impl_ == nullptr) { MS_LOG(ERROR) << "Model implement is null."; return Status(kLiteFileError, "Fail to load config file."); @@ -276,7 +382,7 @@ Status Model::UpdateConfig(const std::vector §ion, const std::pair, std::vector> &config) { std::unique_lock impl_lock(g_impl_init_lock); if (impl_ == nullptr) { - impl_ = std::shared_ptr(new (std::nothrow) ModelImpl()); + impl_ = std::make_shared(); } if (impl_ != nullptr) { return impl_->UpdateConfig(CharToString(section), {CharToString(config.first), CharToString(config.second)}); @@ -388,5 +494,4 @@ float Model::GetLearningRate() { } return impl_->GetLearningRate(); } - } // namespace mindspore diff --git a/mindspore/lite/src/train/train_export.cc b/mindspore/lite/src/train/train_export.cc index 6c385c6bed0..e0265ad3083 100644 --- a/mindspore/lite/src/train/train_export.cc +++ b/mindspore/lite/src/train/train_export.cc @@ -26,7 +26,7 @@ #include "schema/inner/model_generated.h" #include "src/train/train_utils.h" #include "src/common/quant_utils.h" -#include "tools/common/meta_graph_serializer.h" +#include "src/common/storage.h" #include "src/train/graph_fusion.h" #include "src/train/graph_dropout.h" #include "src/weight_decoder.h" @@ -553,7 +553,7 @@ int TrainExport::ExportInit(const std::string model_name, std::string version) { return RET_OK; } -int TrainExport::SaveToFile() { return MetaGraphSerializer::Save(*meta_graph_, file_name_); } +int TrainExport::SaveToFile() { return Storage::Save(*meta_graph_, file_name_); } bool TrainExport::IsInputTensor(const schema::TensorT &t) { int total_dims = std::accumulate(t.dims.begin(), t.dims.end(), 1, std::multiplies()); diff --git a/mindspore/lite/test/config/cropped_size.cfg b/mindspore/lite/test/config/cropped_size.cfg index 3c33bc01d40..6628b3691d2 100644 --- a/mindspore/lite/test/config/cropped_size.cfg +++ b/mindspore/lite/test/config/cropped_size.cfg @@ -1 +1 @@ -844020 +848116 diff --git a/mindspore/lite/tools/benchmark/benchmark_base.h b/mindspore/lite/tools/benchmark/benchmark_base.h index af01a017f96..72a310948d2 100644 --- a/mindspore/lite/tools/benchmark/benchmark_base.h +++ b/mindspore/lite/tools/benchmark/benchmark_base.h @@ -69,6 +69,7 @@ constexpr int kNumPrintMin = 5; constexpr const char *DELIM_COLON = ":"; constexpr const char *DELIM_COMMA = ","; constexpr const char *DELIM_SLASH = "/"; +constexpr size_t kEncMaxLen = 16; extern const std::unordered_map kTypeIdMap; extern const std::unordered_map kTensorFormatMap; @@ -139,6 +140,11 @@ class MS_API BenchmarkFlags : public virtual FlagParser { AddFlag(&BenchmarkFlags::cosine_distance_threshold_, "cosineDistanceThreshold", "cosine distance threshold", -1.1); AddFlag(&BenchmarkFlags::resize_dims_in_, "inputShapes", "Shape of input data, the format should be NHWC. e.g. 1,32,32,32:1,1,32,32,1", ""); + AddFlag(&BenchmarkFlags::decrypt_key_str_, "decryptKey", + "The key used to decrypt the file, expressed in hexadecimal characters. Only support AES-GCM and the key " + "length is 16.", + ""); + AddFlag(&BenchmarkFlags::crypto_lib_path_, "cryptoLibPath", "Pass the crypto library path.", ""); AddFlag(&BenchmarkFlags::enable_parallel_predict_, "enableParallelPredict", "Enable model parallel : true | false", false); AddFlag(&BenchmarkFlags::parallel_request_num_, "parallelRequestNum", "parallel request num of parallel predict", @@ -192,6 +198,9 @@ class MS_API BenchmarkFlags : public virtual FlagParser { std::string perf_event_ = "CYCLE"; bool dump_tensor_data_ = false; bool print_tensor_data_ = false; + std::string decrypt_key_str_; + std::string dec_mode_ = "AES-GCM"; + std::string crypto_lib_path_; }; class MS_API BenchmarkBase { diff --git a/mindspore/lite/tools/benchmark/benchmark_unified_api.cc b/mindspore/lite/tools/benchmark/benchmark_unified_api.cc index 54e49e61040..d8be3e642d8 100644 --- a/mindspore/lite/tools/benchmark/benchmark_unified_api.cc +++ b/mindspore/lite/tools/benchmark/benchmark_unified_api.cc @@ -698,7 +698,6 @@ int BenchmarkUnifiedApi::CompareDataGetTotalBiasAndSize(const std::string &name, *total_size += 1; return RET_OK; } - int BenchmarkUnifiedApi::CompareDataGetTotalCosineDistanceAndSize(const std::string &name, mindspore::MSTensor *tensor, float *total_cosine_distance, int *total_size) { if (tensor == nullptr) { @@ -1044,6 +1043,33 @@ int BenchmarkUnifiedApi::RunModelPool(std::shared_ptr contex } #endif +int BenchmarkUnifiedApi::CompileGraph(ModelType model_type, const std::shared_ptr &context, + const std::string &model_name) { + Key dec_key; + if (!flags_->decrypt_key_str_.empty()) { + dec_key.len = lite::Hex2ByteArray(flags_->decrypt_key_str_, dec_key.key, kEncMaxLen); + if (dec_key.len == 0) { + MS_LOG(ERROR) << "dec_key.len == 0"; + return RET_INPUT_PARAM_INVALID; + } + flags_->decrypt_key_str_.clear(); + } + Status ret; + if (flags_->crypto_lib_path_.empty()) { + ret = ms_model_.Build(flags_->model_file_, model_type, context); + } else { + ret = + ms_model_.Build(flags_->model_file_, model_type, context, dec_key, flags_->dec_mode_, flags_->crypto_lib_path_); + } + memset(dec_key.key, 0, kEncMaxLen); + if (ret != kSuccess) { + MS_LOG(ERROR) << "ms_model_.Build failed while running ", model_name.c_str(); + std::cout << "ms_model_.Build failed while running ", model_name.c_str(); + return RET_ERROR; + } + return RET_OK; +} + int BenchmarkUnifiedApi::RunBenchmark() { auto start_prepare_time = GetTimeUs(); @@ -1098,19 +1124,17 @@ int BenchmarkUnifiedApi::RunBenchmark() { } #endif - auto ret = ms_model_.Build(flags_->model_file_, model_type, context); - if (ret != kSuccess) { - MS_LOG(ERROR) << "ms_model_.Build failed while running ", model_name.c_str(); - std::cout << "ms_model_.Build failed while running ", model_name.c_str(); - return RET_ERROR; + status = CompileGraph(model_type, context, model_name); + if (status != RET_OK) { + MS_LOG(ERROR) << "Compile graph failed."; + return status; } - if (!flags_->resize_dims_.empty()) { std::vector> resize_dims; (void)std::transform(flags_->resize_dims_.begin(), flags_->resize_dims_.end(), std::back_inserter(resize_dims), [&](auto &shapes) { return this->ConverterToInt64Vector(shapes); }); - ret = ms_model_.Resize(ms_model_.GetInputs(), resize_dims); + auto ret = ms_model_.Resize(ms_model_.GetInputs(), resize_dims); if (ret != kSuccess) { MS_LOG(ERROR) << "Input tensor resize failed."; std::cout << "Input tensor resize failed."; diff --git a/mindspore/lite/tools/benchmark/benchmark_unified_api.h b/mindspore/lite/tools/benchmark/benchmark_unified_api.h index 9f9a170cb91..4df28192668 100644 --- a/mindspore/lite/tools/benchmark/benchmark_unified_api.h +++ b/mindspore/lite/tools/benchmark/benchmark_unified_api.h @@ -62,6 +62,8 @@ class MS_API BenchmarkUnifiedApi : public BenchmarkBase { float *total_cosine_distance, int *total_size); void InitContext(const std::shared_ptr &context); + int CompileGraph(ModelType model_type, const std::shared_ptr &context, const std::string &model_name); + #ifdef ENABLE_OPENGL_TEXTURE int GenerateGLTexture(std::map *inputGlTexture); diff --git a/mindspore/lite/tools/common/meta_graph_serializer.cc b/mindspore/lite/tools/common/meta_graph_serializer.cc index 14c7e8636ea..ffbf0f49104 100644 --- a/mindspore/lite/tools/common/meta_graph_serializer.cc +++ b/mindspore/lite/tools/common/meta_graph_serializer.cc @@ -206,7 +206,8 @@ bool MetaGraphSerializer::ExtraAndSerializeModelWeight(const schema::MetaGraphT return true; } -bool MetaGraphSerializer::SerializeModelAndUpdateWeight(const schema::MetaGraphT &meta_graphT) { +bool MetaGraphSerializer::SerializeModelAndUpdateWeight(const schema::MetaGraphT &meta_graphT, const Byte *key, + const size_t key_len, const std::string &enc_mode) { // serialize model flatbuffers::FlatBufferBuilder builder(kFlatbuffersBuilderInitSize); auto offset = schema::MetaGraph::Pack(builder, &meta_graphT); @@ -214,7 +215,7 @@ bool MetaGraphSerializer::SerializeModelAndUpdateWeight(const schema::MetaGraphT schema::FinishMetaGraphBuffer(builder, offset); size_t size = builder.GetSize(); auto content = builder.GetBufferPointer(); - if (!SerializeModel(content, size)) { + if (!SerializeModel(content, size, key, key_len, enc_mode)) { MS_LOG(ERROR) << "Serialize graph failed"; return false; } @@ -238,7 +239,8 @@ bool MetaGraphSerializer::SerializeModelAndUpdateWeight(const schema::MetaGraphT return true; } -int MetaGraphSerializer::Save(const schema::MetaGraphT &graph, const std::string &output_path) { +int MetaGraphSerializer::Save(const schema::MetaGraphT &graph, const std::string &output_path, const Byte *key, + const size_t key_len, const std::string &enc_mode) { flatbuffers::FlatBufferBuilder builder(kFlatbuffersBuilderInitSize); auto offset = schema::MetaGraph::Pack(builder, &graph); builder.Finish(offset); @@ -255,7 +257,7 @@ int MetaGraphSerializer::Save(const schema::MetaGraphT &graph, const std::string return RET_ERROR; } if (save_together) { - if (!meta_graph_serializer.SerializeModel(builder.GetBufferPointer(), size)) { + if (!meta_graph_serializer.SerializeModel(builder.GetBufferPointer(), size, key, key_len, enc_mode)) { MS_LOG(ERROR) << "Serialize graph failed"; return RET_ERROR; } @@ -264,7 +266,7 @@ int MetaGraphSerializer::Save(const schema::MetaGraphT &graph, const std::string MS_LOG(ERROR) << "Serialize graph weight failed"; return RET_ERROR; } - if (!meta_graph_serializer.SerializeModelAndUpdateWeight(graph)) { + if (!meta_graph_serializer.SerializeModelAndUpdateWeight(graph, key, key_len, enc_mode)) { MS_LOG(ERROR) << "Serialize graph and adjust weight failed"; return RET_ERROR; } @@ -283,14 +285,25 @@ MetaGraphSerializer::~MetaGraphSerializer() { } } -bool MetaGraphSerializer::SerializeModel(const void *content, size_t size) { +bool MetaGraphSerializer::SerializeModel(const void *content, size_t size, const Byte *key, const size_t key_len, + const std::string &enc_mode) { MS_ASSERT(model_fs_ != nullptr); if (size == 0 || content == nullptr) { MS_LOG(ERROR) << "Input meta graph buffer is nullptr"; return false; } - - model_fs_->write((const char *)content, static_cast(size)); + if (key_len > 0) { + size_t encrypt_len; + auto encrypt_content = Encrypt(&encrypt_len, reinterpret_cast(content), size, key, key_len, enc_mode); + if (encrypt_content == nullptr || encrypt_len == 0) { + MS_LOG(ERROR) << "Encrypt failed."; + model_fs_->close(); + return RET_ERROR; + } + model_fs_->write(reinterpret_cast(encrypt_content.get()), encrypt_len); + } else { + model_fs_->write((const char *)content, static_cast(size)); + } if (model_fs_->bad()) { MS_LOG(ERROR) << "Write model file failed: " << save_model_path_; return RET_ERROR; diff --git a/mindspore/lite/tools/common/meta_graph_serializer.h b/mindspore/lite/tools/common/meta_graph_serializer.h index b0ef4e4c6f1..9f6c42cd090 100644 --- a/mindspore/lite/tools/common/meta_graph_serializer.h +++ b/mindspore/lite/tools/common/meta_graph_serializer.h @@ -21,12 +21,14 @@ #include #include "flatbuffers/flatbuffers.h" #include "schema/inner/model_generated.h" +#include "utils/crypto.h" namespace mindspore::lite { class MetaGraphSerializer { public: // save serialized fb model - static int Save(const schema::MetaGraphT &graph, const std::string &output_path); + static int Save(const schema::MetaGraphT &graph, const std::string &output_path, const Byte *key = {}, + const size_t key_len = 0, const std::string &enc_mode = ""); private: MetaGraphSerializer() = default; @@ -41,9 +43,11 @@ class MetaGraphSerializer { bool ExtraAndSerializeModelWeight(const schema::MetaGraphT &graph); - bool SerializeModelAndUpdateWeight(const schema::MetaGraphT &meta_graphT); + bool SerializeModelAndUpdateWeight(const schema::MetaGraphT &meta_graphT, const Byte *key, const size_t key_len, + const std::string &enc_mode); - bool SerializeModel(const void *content, size_t size); + bool SerializeModel(const void *content, size_t size, const Byte *key, const size_t key_len, + const std::string &enc_mode); private: int64_t cur_offset_ = 0; diff --git a/mindspore/lite/tools/common/string_util.cc b/mindspore/lite/tools/common/string_util.cc index 2019a9bdc1a..7154aa84ec1 100644 --- a/mindspore/lite/tools/common/string_util.cc +++ b/mindspore/lite/tools/common/string_util.cc @@ -18,6 +18,7 @@ #include #include #include +#include namespace mindspore { namespace lite { @@ -126,5 +127,42 @@ bool ConvertDoubleVector(const std::string &str, std::vector *value) { } return true; } + +size_t Hex2ByteArray(const std::string &hex_str, unsigned char *byte_array, size_t max_len) { + std::regex r("[0-9a-fA-F]+"); + if (!std::regex_match(hex_str, r)) { + MS_LOG(ERROR) << "Some characters of dec_key not in [0-9a-fA-F]"; + return 0; + } + if (hex_str.size() % 2 == 1) { // Mod 2 determines whether it is odd + MS_LOG(ERROR) << "the hexadecimal dec_key length must be even"; + return 0; + } + size_t byte_len = hex_str.size() / 2; // Two hexadecimal characters represent a byte + if (byte_len > max_len) { + MS_LOG(ERROR) << "the hexadecimal dec_key length exceeds the maximum limit: " << max_len; + return 0; + } + constexpr int32_t a_val = 10; // The value of 'A' in hexadecimal is 10 + constexpr size_t half_byte_offset = 4; + for (size_t i = 0; i < byte_len; ++i) { + size_t p = i * 2; // The i-th byte is represented by the 2*i and 2*i+1 hexadecimal characters + if (hex_str[p] >= 'a' && hex_str[p] <= 'f') { + byte_array[i] = hex_str[p] - 'a' + a_val; + } else if (hex_str[p] >= 'A' && hex_str[p] <= 'F') { + byte_array[i] = hex_str[p] - 'A' + a_val; + } else { + byte_array[i] = hex_str[p] - '0'; + } + if (hex_str[p + 1] >= 'a' && hex_str[p + 1] <= 'f') { + byte_array[i] = (byte_array[i] << half_byte_offset) | (hex_str[p + 1] - 'a' + a_val); + } else if (hex_str[p] >= 'A' && hex_str[p] <= 'F') { + byte_array[i] = (byte_array[i] << half_byte_offset) | (hex_str[p + 1] - 'A' + a_val); + } else { + byte_array[i] = (byte_array[i] << half_byte_offset) | (hex_str[p + 1] - '0'); + } + } + return byte_len; +} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/common/string_util.h b/mindspore/lite/tools/common/string_util.h index e3fc6553a6c..6f6394f6153 100644 --- a/mindspore/lite/tools/common/string_util.h +++ b/mindspore/lite/tools/common/string_util.h @@ -40,6 +40,8 @@ bool ConvertDoubleNum(const std::string &str, double *value); bool ConvertBool(std::string str, bool *value); bool ConvertDoubleVector(const std::string &str, std::vector *value); + +size_t Hex2ByteArray(const std::string &hex_str, unsigned char *byte_array, size_t max_len); } // namespace lite } // namespace mindspore #endif // MINDSPORE_LITE_SRC_TOOLS_STRING_UTIL_H_ diff --git a/mindspore/lite/tools/common/tensor_util.cc b/mindspore/lite/tools/common/tensor_util.cc index 3f124865718..2b4a7110f64 100644 --- a/mindspore/lite/tools/common/tensor_util.cc +++ b/mindspore/lite/tools/common/tensor_util.cc @@ -343,4 +343,32 @@ int GenerateRandomData(mindspore::tensor::MSTensor *tensor) { } return RET_OK; } + +int GenerateRandomData(mindspore::MSTensor *tensor) { + MS_ASSERT(tensor != nullptr); + auto input_data = tensor->MutableData(); + if (input_data == nullptr) { + MS_LOG(ERROR) << "MallocData for inTensor failed"; + return RET_ERROR; + } + int status = RET_ERROR; + if (static_cast(tensor->DataType()) == kObjectTypeString) { + MSTensor *input = MSTensor::StringsToTensor(tensor->Name(), {"you're the best."}); + if (input == nullptr) { + std::cerr << "StringsToTensor failed" << std::endl; + MS_LOG(ERROR) << "StringsToTensor failed"; + return RET_ERROR; + } + *tensor = *input; + delete input; + } else { + status = GenerateRandomData(tensor->DataSize(), input_data, static_cast(tensor->DataType())); + } + if (status != RET_OK) { + std::cerr << "GenerateRandomData for inTensor failed: " << status << std::endl; + MS_LOG(ERROR) << "GenerateRandomData for inTensor failed:" << status; + return status; + } + return RET_OK; +} } // namespace mindspore::lite diff --git a/mindspore/lite/tools/common/tensor_util.h b/mindspore/lite/tools/common/tensor_util.h index 618fc76f505..e62d5400b73 100644 --- a/mindspore/lite/tools/common/tensor_util.h +++ b/mindspore/lite/tools/common/tensor_util.h @@ -78,6 +78,8 @@ std::unique_ptr CopyQuantParamT(const std::unique_ptr diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 338fee8d6c4..3bf469e4330 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -1,7 +1,9 @@ add_definitions(-DPRIMITIVE_WRITEABLE) add_definitions(-DUSE_GLOG) set(USE_GLOG on) - +if(MSLITE_ENABLE_MODEL_ENCRYPTION) + add_compile_definitions(ENABLE_OPENSSL) +endif() set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src) set(CCSRC_SRC @@ -13,8 +15,8 @@ set(CCSRC_SRC include_directories(${TOP_DIR}/mindspore/ccsrc/plugin/device/cpu/kernel) if(NOT WIN32 AND NOT MSLITE_ENABLE_ACL) - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -rdynamic -fvisibility=hidden") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -rdynamic -fvisibility=hidden") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -rdynamic -fvisibility=hidden") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -rdynamic -fvisibility=hidden") endif() file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} @@ -50,11 +52,10 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/import/mindir_adjust.cc ${CMAKE_CURRENT_SOURCE_DIR}/import/mindir_control_flow_adjust.cc ${CMAKE_CURRENT_SOURCE_DIR}/adapter/acl/acl_pass.cc - ${SRC_DIR}/common/quant_utils.cc ${SRC_DIR}/common/dynamic_library_loader.cc ${SRC_DIR}/train/train_populate_parameter.cc - + ${SRC_DIR}/common/config_file.cc ../optimizer/*.cc ) @@ -76,16 +77,20 @@ add_subdirectory(micro/coder) if(MSLITE_ENABLE_ACL) set(MODE_ASCEND_ACL ON) + include_directories(${TOP_DIR}/graphengine/inc/external) include(${TOP_DIR}/cmake/dependency_graphengine.cmake) add_subdirectory(adapter/acl) link_directories(${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) endif() -set(API_SRC ${SRC_DIR}/cxx_api/context.cc) -if(MSLITE_ENABLE_ACL) - list(APPEND API_SRC ${SRC_DIR}/cxx_api/kernel.cc) -endif() +file(GLOB CXX_API_SRCS + ${SRC_DIR}/cxx_api/*.cc + ${SRC_DIR}/cxx_api/model/*.cc + ${SRC_DIR}/cxx_api/graph/*.cc + ${SRC_DIR}/cxx_api/tensor/*.cc) + set(LITE_SRC ${API_SRC} + ${CXX_API_SRCS} ${SRC_DIR}/ops/ops_def.cc ${SRC_DIR}/ops/ops_utils.cc ${SRC_DIR}/common/utils.cc @@ -97,6 +102,7 @@ set(LITE_SRC ${API_SRC} ${SRC_DIR}/common/log.cc ${SRC_DIR}/common/prim_util.cc ${SRC_DIR}/common/tensor_util.cc + ${SRC_DIR}/common/decrypt.cc ${SRC_DIR}/runtime/allocator.cc ${SRC_DIR}/runtime/inner_allocator.cc ${SRC_DIR}/runtime/runtime_allocator.cc diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index e9e74ced6b3..0f64f938783 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -33,9 +33,15 @@ #include "tools/converter/import/mindspore_importer.h" #include "nnacl/op_base.h" #include "tools/converter/micro/coder/coder.h" +#include "src/common/prim_util.h" +#include "src/common/version_manager.h" +#include "tools/common/tensor_util.h" +#include "include/api/model.h" + namespace mindspore { namespace lite { namespace { +constexpr size_t kMaxNum1024 = 1024; void InitConverterParameters(const converter::Flags &flag, converter::ConverterParameters *converter_parameters) { MS_ASSERT(converter_parameters != nullptr); converter_parameters->fmk = flag.fmk; @@ -178,6 +184,90 @@ schema::MetaGraphT *Converter::TransferFuncGraph(const std::unique_ptrnodes) { + auto prim = ConvertToPrimitive(node->primitive.get(), &fbb); + if (prim == nullptr) { + MS_LOG(ERROR) << "get primitive failed."; + fbb.Clear(); + return RET_ERROR; + } + if (IsCustomNode(prim, static_cast(SCHEMA_CUR))) { + *exist_custom_nodes = true; + break; + } + } + fbb.Clear(); + return RET_OK; +} + +int PreInference(const schema::MetaGraphT &meta_graph, const std::unique_ptr &flags) { + if (flags->trainModel) { + MS_LOG(WARNING) << "train model dont support pre-infer."; + return RET_OK; + } + + bool exist_custom_nodes = false; + auto check_ret = CheckExistCustomOps(&meta_graph, &exist_custom_nodes); + if (check_ret == RET_ERROR) { + MS_LOG(ERROR) << "CheckExistCustomOps failed."; + return RET_ERROR; + } + if (exist_custom_nodes) { + MS_LOG(WARNING) << "exist custom nodes and will not be pre-infer."; + return RET_OK; + } + mindspore::Model model; + flatbuffers::FlatBufferBuilder builder(kMaxNum1024); + auto offset = schema::MetaGraph::Pack(builder, &meta_graph); + builder.Finish(offset); + schema::FinishMetaGraphBuffer(builder, offset); + int size = builder.GetSize(); + auto content = builder.GetBufferPointer(); + if (content == nullptr) { + MS_LOG(ERROR) << "GetBufferPointer nullptr"; + return RET_ERROR; + } + auto context = std::make_shared(); + if (context == nullptr) { + MS_LOG(ERROR) << "New context failed while running "; + std::cerr << "New context failed while running " << std::endl; + return RET_ERROR; + } + std::shared_ptr device_info = std::make_shared(); + auto &device_list = context->MutableDeviceInfo(); + device_list.push_back(device_info); + + auto ret = model.Build(content, size, kMindIR, context); + if (ret != kSuccess) { + MS_LOG(ERROR) << "Build error "; + std::cerr << "Build error " << std::endl; + return RET_ERROR; + } + for (auto &tensor : model.GetInputs()) { + if (tensor.Shape().empty() || tensor.DataSize() <= 0 || + std::find(tensor.Shape().begin(), tensor.Shape().end(), -1) != tensor.Shape().end()) { + MS_LOG(WARNING) << tensor.Name() << " is dynamic shape and will not be pre-infer."; + return RET_OK; + } + auto status = GenerateRandomData(&tensor); + if (status != RET_OK) { + MS_LOG(ERROR) << tensor.Name() << "GenerateRandomData failed."; + return status; + } + } + std::vector outputs; + ret = model.Predict(model.GetInputs(), &outputs); + if (ret != kSuccess) { + MS_LOG(ERROR) << "Inference error "; + std::cerr << "Inference error " << std::endl; + return RET_ERROR; + } + return RET_OK; +} + int RunConverter(int argc, const char **argv) { std::ostringstream oss; auto flags = std::make_unique(); @@ -215,6 +305,18 @@ int RunConverter(int argc, const char **argv) { // save graph to file meta_graph->version = Version(); + if (flags->infer) { + status = PreInference(*meta_graph, flags); + if (status != RET_OK) { + oss.clear(); + oss << "PRE INFERENCE FAILED:" << status << " " << GetErrorInfo(status); + MS_LOG(ERROR) << oss.str(); + std::cout << oss.str() << std::endl; + delete meta_graph; + return status; + } + } + if (flags->microParam.enable_micro) { status = micro::Coder::MicroSourceCodeGeneration(*meta_graph, flags->outputFile, flags->microParam.codegen_mode, flags->microParam.target, flags->microParam.support_parallel, @@ -228,7 +330,7 @@ int RunConverter(int argc, const char **argv) { return status; } } else { - status = MetaGraphSerializer::Save(*meta_graph, flags->outputFile); + status = MetaGraphSerializer::Save(*meta_graph, flags->outputFile, flags->encKey, flags->keyLen, flags->encMode); if (status != RET_OK) { delete meta_graph; oss.clear(); @@ -238,7 +340,12 @@ int RunConverter(int argc, const char **argv) { return status; } } - + // clear key + status = memset_s(flags->encKey, converter::kEncMaxLen, 0, converter::kEncMaxLen); + if (status != EOK) { + MS_LOG(ERROR) << "memset failed."; + return RET_ERROR; + } delete meta_graph; oss.clear(); oss << "CONVERT RESULT SUCCESS:" << status; diff --git a/mindspore/lite/tools/converter/converter_flags.cc b/mindspore/lite/tools/converter/converter_flags.cc index e4cfafb3fde..3b18495b6a9 100644 --- a/mindspore/lite/tools/converter/converter_flags.cc +++ b/mindspore/lite/tools/converter/converter_flags.cc @@ -83,6 +83,20 @@ Flags::Flags() { ""); AddFlag(&Flags::graphInputFormatStr, "inputDataFormat", "Assign the input format of exported model. Only Valid for 4-dimensional input. NHWC | NCHW", "NHWC"); +#ifdef ENABLE_OPENSSL + AddFlag(&Flags::encryptionStr, "encryption", + "Whether to export the encryption model." + "true | false", + "true"); + AddFlag(&Flags::encKeyStr, "encryptKey", + "The key used to encrypt the file, expressed in hexadecimal characters. Only support AES-GCM and the key " + "length is 16.", + ""); +#endif + AddFlag(&Flags::inferStr, "infer", + "Whether to do pre-inference after convert." + "true | false", + "false"); } int Flags::InitInputOutputDataType() { @@ -310,8 +324,56 @@ int Flags::InitConfigFile() { return RET_OK; } -int Flags::Init(int argc, const char **argv) { - int ret; +int Flags::InitSaveFP16() { + if (saveFP16Str == "on") { + saveFP16 = true; + } else if (saveFP16Str == "off") { + saveFP16 = false; + } else { + std::cerr << "Init save_fp16 failed." << std::endl; + return RET_INPUT_PARAM_INVALID; + } + return RET_OK; +} + +int Flags::InitPreInference() { + if (this->inferStr == "true") { + this->infer = true; + } else if (this->inferStr == "false") { + this->infer = false; + } else { + std::cerr << "INPUT ILLEGAL: infer must be true|false " << std::endl; + return RET_INPUT_PARAM_INVALID; + } + return RET_OK; +} + +int Flags::InitEncrypt() { + if (this->encryptionStr == "true") { + this->encryption = true; + } else if (this->encryptionStr == "false") { + this->encryption = false; + } else { + std::cerr << "INPUT ILLEGAL: encryption must be true|false " << std::endl; + return RET_INPUT_PARAM_INVALID; + } + if (this->encryption) { + if (encKeyStr.empty()) { + MS_LOG(ERROR) << "If you don't need to use model encryption, please set --encryption=false."; + return RET_INPUT_PARAM_INVALID; + } + keyLen = lite::Hex2ByteArray(encKeyStr, encKey, kEncMaxLen); + if (keyLen != kEncMaxLen) { + MS_LOG(ERROR) << "enc_key " << encKeyStr << " must expressed in hexadecimal characters " + << " and only support AES-GCM method and the key length is 16."; + return RET_INPUT_PARAM_INVALID; + } + encKeyStr.clear(); + } + return RET_OK; +} + +int Flags::PreInit(int argc, const char **argv) { if (argc == 1) { std::cout << this->Usage() << std::endl; return lite::RET_SUCCESS_EXIT; @@ -353,19 +415,23 @@ int Flags::Init(int argc, const char **argv) { } if (!this->configFile.empty()) { - ret = InitConfigFile(); + auto ret = InitConfigFile(); if (ret != RET_OK) { std::cerr << "Init config file failed." << std::endl; return RET_INPUT_PARAM_INVALID; } } + return RET_OK; +} - if (saveFP16Str == "on") { - saveFP16 = true; - } else if (saveFP16Str == "off") { - saveFP16 = false; - } else { - std::cerr << "Init save_fp16 failed." << std::endl; +int Flags::Init(int argc, const char **argv) { + auto ret = PreInit(argc, argv); + if (ret != RET_OK) { + return ret; + } + ret = InitSaveFP16(); + if (ret != RET_OK) { + std::cerr << "Init save fp16 failed." << std::endl; return RET_INPUT_PARAM_INVALID; } @@ -398,8 +464,25 @@ int Flags::Init(int argc, const char **argv) { std::cerr << "Init graph input format failed." << std::endl; return RET_INPUT_PARAM_INVALID; } + + ret = InitEncrypt(); + if (ret != RET_OK) { + std::cerr << "Init encrypt failed." << std::endl; + return RET_INPUT_PARAM_INVALID; + } + + ret = InitPreInference(); + if (ret != RET_OK) { + std::cerr << "Init pre inference failed." << std::endl; + return RET_INPUT_PARAM_INVALID; + } return RET_OK; } +Flags::~Flags() { + dec_key.clear(); + encKeyStr.clear(); + memset(encKey, 0, kEncMaxLen); +} bool CheckOfflineParallelConfig(const std::string &file, ParallelSplitConfig *parallel_split_config) { // device: [device0 device1] ---> {cpu, gpu} diff --git a/mindspore/lite/tools/converter/converter_flags.h b/mindspore/lite/tools/converter/converter_flags.h index b98e2d2b3ec..a023f67be7f 100644 --- a/mindspore/lite/tools/converter/converter_flags.h +++ b/mindspore/lite/tools/converter/converter_flags.h @@ -40,6 +40,7 @@ constexpr auto kMaxSplitRatio = 10; constexpr auto kComputeRate = "computeRate"; constexpr auto kSplitDevice0 = "device0"; constexpr auto kSplitDevice1 = "device1"; +constexpr size_t kEncMaxLen = 16; struct ParallelSplitConfig { ParallelSplitType parallel_split_type_ = SplitNo; std::vector parallel_compute_rates_; @@ -50,7 +51,7 @@ class Flags : public virtual mindspore::lite::FlagParser { public: Flags(); - ~Flags() override = default; + ~Flags() override; int InitInputOutputDataType(); @@ -66,8 +67,16 @@ class Flags : public virtual mindspore::lite::FlagParser { int InitExtendedIntegrationInfo(const lite::ConfigFileParser &config_file_parser); + int InitEncrypt(); + + int InitPreInference(); + + int InitSaveFP16(); + int Init(int argc, const char **argv); + int PreInit(int argc, const char **argv); + std::string modelFile; std::string outputFile; std::string fmkIn; @@ -91,7 +100,19 @@ class Flags : public virtual mindspore::lite::FlagParser { std::string graphInputFormatStr; std::string device; mindspore::Format graphInputFormat = mindspore::NHWC; - bool enable_micro = false; + std::string encKeyStr; + std::string encMode = "AES-GCM"; + std::string inferStr; +#ifdef ENABLE_OPENSSL + std::string encryptionStr = "true"; + bool encryption = true; +#else + std::string encryptionStr = "false"; + bool encryption = false; +#endif + bool infer = false; + unsigned char encKey[kEncMaxLen]; + size_t keyLen = 0; lite::quant::CommonQuantParam commonQuantParam; lite::quant::MixedBitWeightQuantParam mixedBitWeightQuantParam; diff --git a/tests/ut/data/mindir/add_encrpty_key_0123456789ABCDEF.mindir b/tests/ut/data/mindir/add_encrpty_key_0123456789ABCDEF.mindir index cb5530033e1..5dc0d8d8b21 100644 Binary files a/tests/ut/data/mindir/add_encrpty_key_0123456789ABCDEF.mindir and b/tests/ut/data/mindir/add_encrpty_key_0123456789ABCDEF.mindir differ