diff --git a/mindspore/core/load_mindir/load_model.cc b/mindspore/core/load_mindir/load_model.cc index ed3bdf65259..62574ee7db2 100644 --- a/mindspore/core/load_mindir/load_model.cc +++ b/mindspore/core/load_mindir/load_model.cc @@ -122,8 +122,8 @@ bool get_all_files(const std::string &dir_in, std::vector *files) { int endsWith(string s, string sub) { return s.rfind(sub) == (s.length() - sub.length()) ? 1 : 0; } -bool ParseModelProto(mind_ir::ModelProto *model, std::string path, const unsigned char *dec_key, const size_t key_len, - const std::string &dec_mode) { +bool ParseModelProto(mind_ir::ModelProto *model, const std::string &path, const unsigned char *dec_key, + const size_t key_len, const std::string &dec_mode) { if (dec_key != nullptr) { size_t plain_len; auto plain_data = Decrypt(&plain_len, path, dec_key, key_len, dec_mode); @@ -131,7 +131,7 @@ bool ParseModelProto(mind_ir::ModelProto *model, std::string path, const unsigne MS_LOG(ERROR) << "Decrypt MindIR file failed, please check the correctness of the dec_key or dec_mode."; return false; } - if (!model->ParseFromArray(reinterpret_cast(plain_data.get()), plain_len)) { + if (!model->ParseFromArray(reinterpret_cast(plain_data.get()), static_cast(plain_len))) { MS_LOG(ERROR) << "Load MindIR file failed, please check the correctness of the file, dec_key or dec_mode."; return false; } @@ -145,8 +145,8 @@ bool ParseModelProto(mind_ir::ModelProto *model, std::string path, const unsigne return true; } -bool ParseGraphProto(mind_ir::GraphProto *graph, std::string path, const unsigned char *dec_key, const size_t key_len, - const std::string &dec_mode) { +bool ParseGraphProto(mind_ir::GraphProto *graph, const std::string &path, const unsigned char *dec_key, + const size_t key_len, const std::string &dec_mode) { if (dec_key != nullptr) { size_t plain_len; auto plain_data = Decrypt(&plain_len, path, dec_key, key_len, dec_mode); @@ -154,7 +154,7 @@ bool ParseGraphProto(mind_ir::GraphProto *graph, std::string path, const unsigne MS_LOG(ERROR) << "Decrypt MindIR file failed, please check the correctness of the dec_key or dec_mode."; return false; } - if (!graph->ParseFromArray(reinterpret_cast(plain_data.get()), plain_len)) { + if (!graph->ParseFromArray(reinterpret_cast(plain_data.get()), static_cast(plain_len))) { MS_LOG(ERROR) << "Load variable file failed, please check the correctness of the mindir's variable file, " "dec_key or dec_mode"; return false; diff --git a/mindspore/core/utils/crypto.cc b/mindspore/core/utils/crypto.cc index 4a44080529f..4ea1e5aeffb 100644 --- a/mindspore/core/utils/crypto.cc +++ b/mindspore/core/utils/crypto.cc @@ -52,17 +52,18 @@ bool IsCipherFile(const std::string &file_path) { return false; } std::vector int_buf(sizeof(int32_t)); - fid.read(int_buf.data(), sizeof(int32_t)); + fid.read(int_buf.data(), static_cast(sizeof(int32_t))); fid.close(); auto flag = ByteToInt(reinterpret_cast(int_buf.data()), int_buf.size()); - return flag == MAGIC_NUM; + return static_cast(flag) == MAGIC_NUM; } bool IsCipherFile(const Byte *model_data) { + MS_EXCEPTION_IF_NULL(model_data); std::vector int_buf; int_buf.assign(model_data, model_data + sizeof(int32_t)); auto flag = ByteToInt(int_buf.data(), int_buf.size()); - return flag == MAGIC_NUM; + return static_cast(flag) == MAGIC_NUM; } #if defined(_WIN32) std::unique_ptr Encrypt(size_t *encrypt_len, const Byte *plain_data, size_t plain_len, const Byte *key, @@ -93,7 +94,9 @@ bool ParseEncryptData(const Byte *encrypt_data, size_t encrypt_len, std::vector< int_buf.assign(encrypt_data + iv_len + sizeof(int32_t), encrypt_data + iv_len + sizeof(int32_t) + sizeof(int32_t)); auto cipher_len = ByteToInt(int_buf.data(), int_buf.size()); - if (iv_len <= 0 || cipher_len <= 0 || ((iv_len + sizeof(int32_t) + cipher_len + sizeof(int32_t)) != encrypt_len)) { + if (iv_len <= 0 || cipher_len <= 0 || + ((static_cast(iv_len) + sizeof(int32_t) + static_cast(cipher_len) + sizeof(int32_t)) != + encrypt_len)) { MS_LOG(ERROR) << "Failed to parse encrypt data."; return false; } @@ -108,11 +111,10 @@ bool ParseEncryptData(const Byte *encrypt_data, size_t encrypt_len, std::vector< bool ParseMode(const std::string &mode, std::string *alg_mode, std::string *work_mode) { std::smatch results; std::regex re("([A-Z]{3})-([A-Z]{3})"); - if (!std::regex_match(mode.c_str(), re)) { + if (!(std::regex_match(mode.c_str(), re) && std::regex_search(mode, results, re))) { MS_LOG(ERROR) << "Mode " << mode << " is invalid."; return false; } - std::regex_search(mode, results, re); *alg_mode = results[1]; *work_mode = results[2]; return true; @@ -171,7 +173,13 @@ EVP_CIPHER_CTX *GetEvpCipherCtx(const std::string &work_mode, const Byte *key, i MS_LOG(ERROR) << "EVP_EncryptInit_ex failed"; return nullptr; } - if (work_mode == "CBC") EVP_CIPHER_CTX_set_padding(ctx, 1); + if (work_mode == "CBC") { + ret = EVP_CIPHER_CTX_set_padding(ctx, 1); + if (ret != 1) { + MS_LOG(ERROR) << "EVP_CIPHER_CTX_set_padding failed"; + return nullptr; + } + } return ctx; } @@ -181,7 +189,11 @@ bool BlockEncrypt(Byte *encrypt_data, size_t *encrypt_data_len, const std::vecto int32_t cipher_len = 0; int32_t iv_len = AES_BLOCK_SIZE; std::vector iv(iv_len); - RAND_bytes(iv.data(), sizeof(Byte) * iv_len); + auto ret = RAND_bytes(iv.data(), iv_len); + if (ret != 1) { + MS_LOG(ERROR) << "RAND_bytes error, failed to init iv."; + return false; + } std::vector iv_cpy(iv); std::string alg_mode; @@ -197,23 +209,28 @@ bool BlockEncrypt(Byte *encrypt_data, size_t *encrypt_data_len, const std::vecto } std::vector cipher_data_buf(plain_data.size() + AES_BLOCK_SIZE); - auto ret_evp = EVP_EncryptUpdate(ctx, cipher_data_buf.data(), &cipher_len, plain_data.data(), plain_data.size()); + auto ret_evp = EVP_EncryptUpdate(ctx, cipher_data_buf.data(), &cipher_len, plain_data.data(), + static_cast(plain_data.size())); if (ret_evp != 1) { MS_LOG(ERROR) << "EVP_EncryptUpdate failed"; return false; } if (work_mode == "CBC") { int32_t flen = 0; - EVP_EncryptFinal_ex(ctx, cipher_data_buf.data() + cipher_len, &flen); + ret_evp = EVP_EncryptFinal_ex(ctx, cipher_data_buf.data() + cipher_len, &flen); + if (ret_evp != 1) { + MS_LOG(ERROR) << "EVP_EncryptFinal_ex failed"; + return false; + } cipher_len += flen; } EVP_CIPHER_CTX_free(ctx); size_t offset = 0; std::vector int_buf(sizeof(int32_t)); - *encrypt_data_len = sizeof(int32_t) + iv_len + sizeof(int32_t) + cipher_len; - IntToByte(&int_buf, *encrypt_data_len); - auto ret = memcpy_s(encrypt_data, encrypt_data_buf_len, int_buf.data(), int_buf.size()); + *encrypt_data_len = sizeof(int32_t) + static_cast(iv_len) + sizeof(int32_t) + static_cast(cipher_len); + IntToByte(&int_buf, static_cast(*encrypt_data_len)); + ret = memcpy_s(encrypt_data, encrypt_data_buf_len, int_buf.data(), int_buf.size()); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret; } @@ -239,7 +256,8 @@ bool BlockEncrypt(Byte *encrypt_data, size_t *encrypt_data_len, const std::vecto } offset += int_buf.size(); - ret = memcpy_s(encrypt_data + offset, encrypt_data_buf_len - offset, cipher_data_buf.data(), cipher_len); + ret = memcpy_s(encrypt_data + offset, encrypt_data_buf_len - offset, cipher_data_buf.data(), + static_cast(cipher_len)); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret; } @@ -265,7 +283,8 @@ bool BlockDecrypt(Byte *plain_data, int32_t *plain_len, const Byte *encrypt_data MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX."; return false; } - auto ret = EVP_DecryptUpdate(ctx, plain_data, plain_len, cipher_data.data(), cipher_data.size()); + auto ret = + EVP_DecryptUpdate(ctx, plain_data, plain_len, cipher_data.data(), static_cast(cipher_data.size())); if (ret != 1) { MS_LOG(ERROR) << "EVP_DecryptUpdate failed"; return false; @@ -285,6 +304,9 @@ bool BlockDecrypt(Byte *plain_data, int32_t *plain_len, const Byte *encrypt_data std::unique_ptr Encrypt(size_t *encrypt_len, const Byte *plain_data, size_t plain_len, const Byte *key, size_t key_len, const std::string &enc_mode) { + MS_EXCEPTION_IF_NULL(plain_data); + MS_EXCEPTION_IF_NULL(key); + size_t block_enc_buf_len = MAX_BLOCK_SIZE + RESERVED_BYTE_PER_BLOCK; size_t encrypt_buf_len = plain_len + (plain_len + MAX_BLOCK_SIZE) / MAX_BLOCK_SIZE * RESERVED_BYTE_PER_BLOCK; std::vector int_buf(sizeof(int32_t)); @@ -298,12 +320,12 @@ std::unique_ptr Encrypt(size_t *encrypt_len, const Byte *plain_data, siz size_t block_enc_len = block_enc_buf.size(); size_t cur_block_size = std::min(MAX_BLOCK_SIZE, plain_len - offset); block_buf.assign(plain_data + offset, plain_data + offset + cur_block_size); - if (!BlockEncrypt(block_enc_buf.data(), &block_enc_len, block_buf, key, key_len, enc_mode)) { + if (!BlockEncrypt(block_enc_buf.data(), &block_enc_len, block_buf, key, static_cast(key_len), enc_mode)) { MS_LOG(ERROR) << "Failed to encrypt data, please check if enc_key or enc_mode is valid."; return nullptr; } - IntToByte(&int_buf, MAGIC_NUM); + IntToByte(&int_buf, static_cast(MAGIC_NUM)); size_t capacity = std::min(encrypt_buf_len - *encrypt_len, SECUREC_MEM_MAX_LEN); // avoid dest size over 2gb auto ret = memcpy_s(encrypt_data.get() + *encrypt_len, capacity, int_buf.data(), sizeof(int32_t)); if (ret != 0) { @@ -324,13 +346,15 @@ std::unique_ptr Encrypt(size_t *encrypt_len, const Byte *plain_data, siz std::unique_ptr Decrypt(size_t *decrypt_len, const std::string &encrypt_data_path, const Byte *key, size_t key_len, const std::string &dec_mode) { + MS_EXCEPTION_IF_NULL(key); + std::ifstream fid(encrypt_data_path, std::ios::in | std::ios::binary); if (!fid) { MS_LOG(ERROR) << "Open file '" << encrypt_data_path << "' failed, please check the correct of the file."; return nullptr; } fid.seekg(0, std::ios_base::end); - size_t file_size = fid.tellg(); + size_t file_size = static_cast(fid.tellg()); fid.clear(); fid.seekg(0); @@ -342,26 +366,31 @@ std::unique_ptr Decrypt(size_t *decrypt_len, const std::string &encrypt_ *decrypt_len = 0; while (static_cast(fid.tellg()) < file_size) { - fid.read(int_buf.data(), sizeof(int32_t)); + fid.read(int_buf.data(), static_cast(sizeof(int32_t))); auto cipher_flag = ByteToInt(reinterpret_cast(int_buf.data()), int_buf.size()); - if (cipher_flag != MAGIC_NUM) { + if (static_cast(cipher_flag) != MAGIC_NUM) { MS_LOG(ERROR) << "File \"" << encrypt_data_path << "\" is not an encrypted file and cannot be decrypted"; return nullptr; } - fid.read(int_buf.data(), sizeof(int32_t)); + fid.read(int_buf.data(), static_cast(sizeof(int32_t))); auto block_size = ByteToInt(reinterpret_cast(int_buf.data()), int_buf.size()); - fid.read(block_buf.data(), sizeof(char) * block_size); + if (block_size < 0) { + MS_LOG(ERROR) << "The block_size read from the cipher file must be not negative, but got " << block_size; + return nullptr; + } + fid.read(block_buf.data(), static_cast(block_size)); if (!(BlockDecrypt(decrypt_block_buf.data(), &decrypt_block_len, reinterpret_cast(block_buf.data()), - block_size, key, key_len, dec_mode))) { + static_cast(block_size), key, static_cast(key_len), dec_mode))) { MS_LOG(ERROR) << "Failed to decrypt data, please check if dec_key or dec_mode is valid"; return nullptr; } size_t capacity = std::min(file_size - *decrypt_len, SECUREC_MEM_MAX_LEN); - auto ret = memcpy_s(decrypt_data.get() + *decrypt_len, capacity, decrypt_block_buf.data(), decrypt_block_len); + auto ret = memcpy_s(decrypt_data.get() + *decrypt_len, capacity, decrypt_block_buf.data(), + static_cast(decrypt_block_len)); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret; } - *decrypt_len += decrypt_block_len; + *decrypt_len += static_cast(decrypt_block_len); } fid.close(); return decrypt_data; @@ -369,6 +398,9 @@ std::unique_ptr Decrypt(size_t *decrypt_len, const std::string &encrypt_ std::unique_ptr Decrypt(size_t *decrypt_len, const Byte *model_data, size_t data_size, const Byte *key, size_t key_len, const std::string &dec_mode) { + MS_EXCEPTION_IF_NULL(model_data); + MS_EXCEPTION_IF_NULL(key); + std::vector block_buf; std::vector int_buf(sizeof(int32_t)); std::vector decrypt_block_buf(MAX_BLOCK_SIZE); @@ -381,7 +413,7 @@ std::unique_ptr Decrypt(size_t *decrypt_len, const Byte *model_data, siz int_buf.assign(model_data + offset, model_data + offset + sizeof(int32_t)); offset += int_buf.size(); auto cipher_flag = ByteToInt(reinterpret_cast(int_buf.data()), int_buf.size()); - if (cipher_flag != MAGIC_NUM) { + if (static_cast(cipher_flag) != MAGIC_NUM) { MS_LOG(ERROR) << "model_data is not encrypted and therefore cannot be decrypted."; return nullptr; } @@ -389,19 +421,24 @@ std::unique_ptr Decrypt(size_t *decrypt_len, const Byte *model_data, siz int_buf.assign(model_data + offset, model_data + offset + sizeof(int32_t)); offset += int_buf.size(); auto block_size = ByteToInt(reinterpret_cast(int_buf.data()), int_buf.size()); + if (block_size < 0) { + MS_LOG(ERROR) << "The block_size read from the cipher data must be not negative, but got " << block_size; + return nullptr; + } block_buf.assign(model_data + offset, model_data + offset + block_size); offset += block_buf.size(); if (!(BlockDecrypt(decrypt_block_buf.data(), &decrypt_block_len, reinterpret_cast(block_buf.data()), - block_buf.size(), key, key_len, dec_mode))) { + block_buf.size(), key, static_cast(key_len), dec_mode))) { MS_LOG(ERROR) << "Failed to decrypt data, please check if dec_key or dec_mode is valid"; return nullptr; } size_t capacity = std::min(data_size - *decrypt_len, SECUREC_MEM_MAX_LEN); - auto ret = memcpy_s(decrypt_data.get() + *decrypt_len, capacity, decrypt_block_buf.data(), decrypt_block_len); + auto ret = memcpy_s(decrypt_data.get() + *decrypt_len, capacity, decrypt_block_buf.data(), + static_cast(decrypt_block_len)); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret; } - *decrypt_len += decrypt_block_len; + *decrypt_len += static_cast(decrypt_block_len); } return decrypt_data; } diff --git a/mindspore/lite/tools/converter/import/mindspore_importer.cc b/mindspore/lite/tools/converter/import/mindspore_importer.cc index 01b5f6c68ab..bc06c8886a4 100644 --- a/mindspore/lite/tools/converter/import/mindspore_importer.cc +++ b/mindspore/lite/tools/converter/import/mindspore_importer.cc @@ -186,7 +186,10 @@ FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) { return nullptr; } func_graph = LoadMindIR(flag.modelFile, false, key, key_len, flag.dec_mode); - memset(key, 0, key_len); + auto ret = memset_s(key, sizeof(key), 0, key_len); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memset_s error"; + } } else { func_graph = LoadMindIR(flag.modelFile); }