diff --git a/mindspore/ccsrc/cxx_api/model/model.cc b/mindspore/ccsrc/cxx_api/model/model.cc index 88f19136886..79e1200cbba 100644 --- a/mindspore/ccsrc/cxx_api/model/model.cc +++ b/mindspore/ccsrc/cxx_api/model/model.cc @@ -129,6 +129,10 @@ MSTensor Model::GetOutputByTensorName(const std::vector &tensor_name) { return MSTensor(nullptr); } +std::vector Model::GetOutputsByNodeName(const std::vector &node_name) { + return std::vector{GetOutputByTensorName(node_name)}; +} + Model::Model() : impl_(nullptr) {} Model::~Model() {} diff --git a/mindspore/ccsrc/cxx_api/serialization.cc b/mindspore/ccsrc/cxx_api/serialization.cc index 381bbbceeb7..8c014263087 100644 --- a/mindspore/ccsrc/cxx_api/serialization.cc +++ b/mindspore/ccsrc/cxx_api/serialization.cc @@ -22,13 +22,8 @@ #include "utils/crypto.h" namespace mindspore { -static Buffer ReadFile(const std::string &file) { - Buffer buffer; - if (file.empty()) { - MS_LOG(ERROR) << "Pointer file is nullptr"; - return buffer; - } - +static Status RealPath(const std::string &file, std::string *realpath_str) { + MS_EXCEPTION_IF_NULL(realpath_str); char real_path_mem[PATH_MAX] = {0}; char *real_path_ret = nullptr; #if defined(_WIN32) || defined(_WIN64) @@ -37,19 +32,34 @@ static Buffer ReadFile(const std::string &file) { real_path_ret = realpath(common::SafeCStr(file), real_path_mem); #endif if (real_path_ret == nullptr) { - MS_LOG(ERROR) << "File: " << file << " is not exist."; + return Status(kMEInvalidInput, "File: " + file + " does not exist."); + } + *realpath_str = real_path_mem; + return kSuccess; +} + +static Buffer ReadFile(const std::string &file) { + Buffer buffer; + if (file.empty()) { + MS_LOG(ERROR) << "Pointer file is nullptr"; + return buffer; + } + + std::string real_path; + auto status = RealPath(file, &real_path); + if (status != kSuccess) { + MS_LOG(ERROR) << status.GetErrDescription(); return buffer; } - std::string real_path(real_path_mem); std::ifstream ifs(real_path); if (!ifs.good()) { - MS_LOG(ERROR) << "File: " << real_path << " is not exist"; + MS_LOG(ERROR) << "File: " << real_path << " does not exist"; return buffer; } if (!ifs.is_open()) { - MS_LOG(ERROR) << "File: " << real_path << "open failed"; + MS_LOG(ERROR) << "File: " << real_path << " open failed"; return buffer; } @@ -139,7 +149,13 @@ Status Serialization::Load(const std::vector &file, ModelType model_type, return Status(kMEInvalidInput, err_msg.str()); } - std::string file_path = CharToString(file); + std::string file_path; + auto status = RealPath(CharToString(file), &file_path); + if (status != kSuccess) { + MS_LOG(ERROR) << status.GetErrDescription(); + return status; + } + if (model_type == kMindIR) { FuncGraphPtr anf_graph; if (dec_key.len > dec_key.max_key_len) { @@ -193,7 +209,17 @@ Status Serialization::Load(const std::vector> &files, ModelTyp return ret; } - std::vector files_path = VectorCharToString(files); + std::vector files_path; + for (const auto &file : files) { + std::string file_path; + auto status = RealPath(CharToString(file), &file_path); + if (status != kSuccess) { + MS_LOG(ERROR) << status.GetErrDescription(); + return status; + } + files_path.emplace_back(std::move(file_path)); + } + if (model_type == kMindIR) { if (dec_key.len > dec_key.max_key_len) { err_msg << "The key length exceeds maximum length: " << dec_key.max_key_len; diff --git a/tests/ut/cpp/cxx_api/serialization_test.cc b/tests/ut/cpp/cxx_api/serialization_test.cc new file mode 100644 index 00000000000..61cbad279b2 --- /dev/null +++ b/tests/ut/cpp/cxx_api/serialization_test.cc @@ -0,0 +1,115 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "common/common_test.h" +#include "include/api/serialization.h" + +namespace mindspore { +class TestCxxApiSerialization : public UT::Common { + public: + TestCxxApiSerialization() = default; +}; + +TEST_F(TestCxxApiSerialization, test_load_no_encrpty_mindir_SUCCESS) { + Graph graph; + ASSERT_TRUE(Serialization::Load("./data/mindir/add_no_encrpty.mindir", ModelType::kMindIR, &graph) == kSuccess); +} + +TEST_F(TestCxxApiSerialization, test_load_output_args_nullptr_FAILED) { + auto status = Serialization::Load("./data/mindir/add_no_encrpty.mindir", ModelType::kMindIR, nullptr); + ASSERT_TRUE(status != kSuccess); + auto err_mst = status.GetErrDescription(); + ASSERT_TRUE(err_mst.find("null") != std::string::npos); +} + +TEST_F(TestCxxApiSerialization, test_load_file_not_exist_FAILED) { + Graph graph; + auto status = Serialization::Load("./data/mindir/file_not_exist.mindir", ModelType::kMindIR, &graph); + ASSERT_TRUE(status != kSuccess); + auto err_mst = status.GetErrDescription(); + ASSERT_TRUE(err_mst.find("exist") != std::string::npos); +} + +TEST_F(TestCxxApiSerialization, test_load_encrpty_mindir_SUCCESS) { + Graph graph; + std::string key_str = "0123456789ABCDEF"; + Key key; + memcpy(key.key, key_str.c_str(), key_str.size()); + key.len = key_str.size(); + ASSERT_TRUE(Serialization::Load("./data/mindir/add_encrpty_key_0123456789ABCDEF.mindir", ModelType::kMindIR, &graph, + key, "AES-GCM") == kSuccess); +} + +TEST_F(TestCxxApiSerialization, test_load_encrpty_mindir_without_key_FAILED) { + Graph graph; + auto status = + Serialization::Load("./data/mindir/add_encrpty_key_0123456789ABCDEF.mindir", ModelType::kMindIR, &graph); + ASSERT_TRUE(status != kSuccess); + auto err_mst = status.GetErrDescription(); + ASSERT_TRUE(err_mst.find("be encrypted") != std::string::npos); +} + +TEST_F(TestCxxApiSerialization, test_load_encrpty_mindir_with_wrong_key_FAILED) { + Graph graph; + std::string key_str = "WRONG_KEY"; + Key key; + memcpy(key.key, key_str.c_str(), key_str.size()); + key.len = key_str.size(); + auto status = Serialization::Load("./data/mindir/add_encrpty_key_0123456789ABCDEF.mindir", ModelType::kMindIR, &graph, + key, "AES-GCM"); + ASSERT_TRUE(status != kSuccess); +} + +TEST_F(TestCxxApiSerialization, test_load_no_encrpty_mindir_with_wrong_key_FAILED) { + Graph graph; + std::string key_str = "WRONG_KEY"; + Key key; + memcpy(key.key, key_str.c_str(), key_str.size()); + key.len = key_str.size(); + auto status = Serialization::Load("./data/mindir/add_no_encrpty.mindir", ModelType::kMindIR, &graph, key, "AES-GCM"); + ASSERT_TRUE(status != kSuccess); +} + +TEST_F(TestCxxApiSerialization, test_load_no_encrpty_mindir_x1_SUCCESS) { + std::vector graphs; + ASSERT_TRUE(Serialization::Load(std::vector(1, "./data/mindir/add_no_encrpty.mindir"), + ModelType::kMindIR, &graphs) == kSuccess); +} + +TEST_F(TestCxxApiSerialization, test_load_no_encrpty_mindir_x2_SUCCESS) { + std::vector graphs; + ASSERT_TRUE(Serialization::Load(std::vector(2, "./data/mindir/add_no_encrpty.mindir"), + ModelType::kMindIR, &graphs) == kSuccess); +} + +TEST_F(TestCxxApiSerialization, test_load_file_not_exist_x2_FAILED) { + std::vector graphs; + auto status = Serialization::Load(std::vector(2, "./data/mindir/file_not_exist.mindir"), + ModelType::kMindIR, &graphs); + ASSERT_TRUE(status != kSuccess); + auto err_mst = status.GetErrDescription(); + ASSERT_TRUE(err_mst.find("exist") != std::string::npos); +} + +TEST_F(TestCxxApiSerialization, test_load_encrpty_mindir_without_key_x2_FAILED) { + std::vector graphs; + auto status = Serialization::Load( + std::vector(2, "./data/mindir/add_encrpty_key_0123456789ABCDEF.mindir"), ModelType::kMindIR, &graphs); + ASSERT_TRUE(status != kSuccess); + auto err_mst = status.GetErrDescription(); + ASSERT_TRUE(err_mst.find("be encrypted") != std::string::npos); +} +} // namespace mindspore diff --git a/tests/ut/data/mindir/add_encrpty_key_0123456789ABCDEF.mindir b/tests/ut/data/mindir/add_encrpty_key_0123456789ABCDEF.mindir new file mode 100644 index 00000000000..cb5530033e1 Binary files /dev/null and b/tests/ut/data/mindir/add_encrpty_key_0123456789ABCDEF.mindir differ diff --git a/tests/ut/data/mindir/add_no_encrpty.mindir b/tests/ut/data/mindir/add_no_encrpty.mindir new file mode 100644 index 00000000000..c977119116c Binary files /dev/null and b/tests/ut/data/mindir/add_no_encrpty.mindir differ