forked from mindspore-Ecosystem/mindspore
!18873 add serialization cpp ut
Merge pull request !18873 from zhoufeng/cpp-api-ut
This commit is contained in:
commit
1d85a97a47
|
@ -129,6 +129,10 @@ MSTensor Model::GetOutputByTensorName(const std::vector<char> &tensor_name) {
|
|||
return MSTensor(nullptr);
|
||||
}
|
||||
|
||||
std::vector<MSTensor> Model::GetOutputsByNodeName(const std::vector<char> &node_name) {
|
||||
return std::vector<MSTensor>{GetOutputByTensorName(node_name)};
|
||||
}
|
||||
|
||||
Model::Model() : impl_(nullptr) {}
|
||||
Model::~Model() {}
|
||||
|
||||
|
|
|
@ -22,13 +22,8 @@
|
|||
#include "utils/crypto.h"
|
||||
|
||||
namespace mindspore {
|
||||
static Buffer ReadFile(const std::string &file) {
|
||||
Buffer buffer;
|
||||
if (file.empty()) {
|
||||
MS_LOG(ERROR) << "Pointer file is nullptr";
|
||||
return buffer;
|
||||
}
|
||||
|
||||
static Status RealPath(const std::string &file, std::string *realpath_str) {
|
||||
MS_EXCEPTION_IF_NULL(realpath_str);
|
||||
char real_path_mem[PATH_MAX] = {0};
|
||||
char *real_path_ret = nullptr;
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
|
@ -37,19 +32,34 @@ static Buffer ReadFile(const std::string &file) {
|
|||
real_path_ret = realpath(common::SafeCStr(file), real_path_mem);
|
||||
#endif
|
||||
if (real_path_ret == nullptr) {
|
||||
MS_LOG(ERROR) << "File: " << file << " is not exist.";
|
||||
return Status(kMEInvalidInput, "File: " + file + " does not exist.");
|
||||
}
|
||||
*realpath_str = real_path_mem;
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
static Buffer ReadFile(const std::string &file) {
|
||||
Buffer buffer;
|
||||
if (file.empty()) {
|
||||
MS_LOG(ERROR) << "Pointer file is nullptr";
|
||||
return buffer;
|
||||
}
|
||||
|
||||
std::string real_path;
|
||||
auto status = RealPath(file, &real_path);
|
||||
if (status != kSuccess) {
|
||||
MS_LOG(ERROR) << status.GetErrDescription();
|
||||
return buffer;
|
||||
}
|
||||
|
||||
std::string real_path(real_path_mem);
|
||||
std::ifstream ifs(real_path);
|
||||
if (!ifs.good()) {
|
||||
MS_LOG(ERROR) << "File: " << real_path << " is not exist";
|
||||
MS_LOG(ERROR) << "File: " << real_path << " does not exist";
|
||||
return buffer;
|
||||
}
|
||||
|
||||
if (!ifs.is_open()) {
|
||||
MS_LOG(ERROR) << "File: " << real_path << "open failed";
|
||||
MS_LOG(ERROR) << "File: " << real_path << " open failed";
|
||||
return buffer;
|
||||
}
|
||||
|
||||
|
@ -139,7 +149,13 @@ Status Serialization::Load(const std::vector<char> &file, ModelType model_type,
|
|||
return Status(kMEInvalidInput, err_msg.str());
|
||||
}
|
||||
|
||||
std::string file_path = CharToString(file);
|
||||
std::string file_path;
|
||||
auto status = RealPath(CharToString(file), &file_path);
|
||||
if (status != kSuccess) {
|
||||
MS_LOG(ERROR) << status.GetErrDescription();
|
||||
return status;
|
||||
}
|
||||
|
||||
if (model_type == kMindIR) {
|
||||
FuncGraphPtr anf_graph;
|
||||
if (dec_key.len > dec_key.max_key_len) {
|
||||
|
@ -193,7 +209,17 @@ Status Serialization::Load(const std::vector<std::vector<char>> &files, ModelTyp
|
|||
return ret;
|
||||
}
|
||||
|
||||
std::vector<std::string> files_path = VectorCharToString(files);
|
||||
std::vector<std::string> files_path;
|
||||
for (const auto &file : files) {
|
||||
std::string file_path;
|
||||
auto status = RealPath(CharToString(file), &file_path);
|
||||
if (status != kSuccess) {
|
||||
MS_LOG(ERROR) << status.GetErrDescription();
|
||||
return status;
|
||||
}
|
||||
files_path.emplace_back(std::move(file_path));
|
||||
}
|
||||
|
||||
if (model_type == kMindIR) {
|
||||
if (dec_key.len > dec_key.max_key_len) {
|
||||
err_msg << "The key length exceeds maximum length: " << dec_key.max_key_len;
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <memory>
|
||||
#include "common/common_test.h"
|
||||
#include "include/api/serialization.h"
|
||||
|
||||
namespace mindspore {
|
||||
class TestCxxApiSerialization : public UT::Common {
|
||||
public:
|
||||
TestCxxApiSerialization() = default;
|
||||
};
|
||||
|
||||
TEST_F(TestCxxApiSerialization, test_load_no_encrpty_mindir_SUCCESS) {
|
||||
Graph graph;
|
||||
ASSERT_TRUE(Serialization::Load("./data/mindir/add_no_encrpty.mindir", ModelType::kMindIR, &graph) == kSuccess);
|
||||
}
|
||||
|
||||
TEST_F(TestCxxApiSerialization, test_load_output_args_nullptr_FAILED) {
|
||||
auto status = Serialization::Load("./data/mindir/add_no_encrpty.mindir", ModelType::kMindIR, nullptr);
|
||||
ASSERT_TRUE(status != kSuccess);
|
||||
auto err_mst = status.GetErrDescription();
|
||||
ASSERT_TRUE(err_mst.find("null") != std::string::npos);
|
||||
}
|
||||
|
||||
TEST_F(TestCxxApiSerialization, test_load_file_not_exist_FAILED) {
|
||||
Graph graph;
|
||||
auto status = Serialization::Load("./data/mindir/file_not_exist.mindir", ModelType::kMindIR, &graph);
|
||||
ASSERT_TRUE(status != kSuccess);
|
||||
auto err_mst = status.GetErrDescription();
|
||||
ASSERT_TRUE(err_mst.find("exist") != std::string::npos);
|
||||
}
|
||||
|
||||
TEST_F(TestCxxApiSerialization, test_load_encrpty_mindir_SUCCESS) {
|
||||
Graph graph;
|
||||
std::string key_str = "0123456789ABCDEF";
|
||||
Key key;
|
||||
memcpy(key.key, key_str.c_str(), key_str.size());
|
||||
key.len = key_str.size();
|
||||
ASSERT_TRUE(Serialization::Load("./data/mindir/add_encrpty_key_0123456789ABCDEF.mindir", ModelType::kMindIR, &graph,
|
||||
key, "AES-GCM") == kSuccess);
|
||||
}
|
||||
|
||||
TEST_F(TestCxxApiSerialization, test_load_encrpty_mindir_without_key_FAILED) {
|
||||
Graph graph;
|
||||
auto status =
|
||||
Serialization::Load("./data/mindir/add_encrpty_key_0123456789ABCDEF.mindir", ModelType::kMindIR, &graph);
|
||||
ASSERT_TRUE(status != kSuccess);
|
||||
auto err_mst = status.GetErrDescription();
|
||||
ASSERT_TRUE(err_mst.find("be encrypted") != std::string::npos);
|
||||
}
|
||||
|
||||
TEST_F(TestCxxApiSerialization, test_load_encrpty_mindir_with_wrong_key_FAILED) {
|
||||
Graph graph;
|
||||
std::string key_str = "WRONG_KEY";
|
||||
Key key;
|
||||
memcpy(key.key, key_str.c_str(), key_str.size());
|
||||
key.len = key_str.size();
|
||||
auto status = Serialization::Load("./data/mindir/add_encrpty_key_0123456789ABCDEF.mindir", ModelType::kMindIR, &graph,
|
||||
key, "AES-GCM");
|
||||
ASSERT_TRUE(status != kSuccess);
|
||||
}
|
||||
|
||||
TEST_F(TestCxxApiSerialization, test_load_no_encrpty_mindir_with_wrong_key_FAILED) {
|
||||
Graph graph;
|
||||
std::string key_str = "WRONG_KEY";
|
||||
Key key;
|
||||
memcpy(key.key, key_str.c_str(), key_str.size());
|
||||
key.len = key_str.size();
|
||||
auto status = Serialization::Load("./data/mindir/add_no_encrpty.mindir", ModelType::kMindIR, &graph, key, "AES-GCM");
|
||||
ASSERT_TRUE(status != kSuccess);
|
||||
}
|
||||
|
||||
TEST_F(TestCxxApiSerialization, test_load_no_encrpty_mindir_x1_SUCCESS) {
|
||||
std::vector<Graph> graphs;
|
||||
ASSERT_TRUE(Serialization::Load(std::vector<std::string>(1, "./data/mindir/add_no_encrpty.mindir"),
|
||||
ModelType::kMindIR, &graphs) == kSuccess);
|
||||
}
|
||||
|
||||
TEST_F(TestCxxApiSerialization, test_load_no_encrpty_mindir_x2_SUCCESS) {
|
||||
std::vector<Graph> graphs;
|
||||
ASSERT_TRUE(Serialization::Load(std::vector<std::string>(2, "./data/mindir/add_no_encrpty.mindir"),
|
||||
ModelType::kMindIR, &graphs) == kSuccess);
|
||||
}
|
||||
|
||||
TEST_F(TestCxxApiSerialization, test_load_file_not_exist_x2_FAILED) {
|
||||
std::vector<Graph> graphs;
|
||||
auto status = Serialization::Load(std::vector<std::string>(2, "./data/mindir/file_not_exist.mindir"),
|
||||
ModelType::kMindIR, &graphs);
|
||||
ASSERT_TRUE(status != kSuccess);
|
||||
auto err_mst = status.GetErrDescription();
|
||||
ASSERT_TRUE(err_mst.find("exist") != std::string::npos);
|
||||
}
|
||||
|
||||
TEST_F(TestCxxApiSerialization, test_load_encrpty_mindir_without_key_x2_FAILED) {
|
||||
std::vector<Graph> graphs;
|
||||
auto status = Serialization::Load(
|
||||
std::vector<std::string>(2, "./data/mindir/add_encrpty_key_0123456789ABCDEF.mindir"), ModelType::kMindIR, &graphs);
|
||||
ASSERT_TRUE(status != kSuccess);
|
||||
auto err_mst = status.GetErrDescription();
|
||||
ASSERT_TRUE(err_mst.find("be encrypted") != std::string::npos);
|
||||
}
|
||||
} // namespace mindspore
|
Binary file not shown.
Binary file not shown.
Loading…
Reference in New Issue