!18873 add serialization cpp ut

Merge pull request !18873 from zhoufeng/cpp-api-ut
This commit is contained in:
i-robot 2021-06-26 10:50:25 +00:00 committed by Gitee
commit 1d85a97a47
5 changed files with 158 additions and 13 deletions

View File

@ -129,6 +129,10 @@ MSTensor Model::GetOutputByTensorName(const std::vector<char> &tensor_name) {
return MSTensor(nullptr);
}
std::vector<MSTensor> Model::GetOutputsByNodeName(const std::vector<char> &node_name) {
return std::vector<MSTensor>{GetOutputByTensorName(node_name)};
}
Model::Model() : impl_(nullptr) {}
Model::~Model() {}

View File

@ -22,13 +22,8 @@
#include "utils/crypto.h"
namespace mindspore {
static Buffer ReadFile(const std::string &file) {
Buffer buffer;
if (file.empty()) {
MS_LOG(ERROR) << "Pointer file is nullptr";
return buffer;
}
static Status RealPath(const std::string &file, std::string *realpath_str) {
MS_EXCEPTION_IF_NULL(realpath_str);
char real_path_mem[PATH_MAX] = {0};
char *real_path_ret = nullptr;
#if defined(_WIN32) || defined(_WIN64)
@ -37,19 +32,34 @@ static Buffer ReadFile(const std::string &file) {
real_path_ret = realpath(common::SafeCStr(file), real_path_mem);
#endif
if (real_path_ret == nullptr) {
MS_LOG(ERROR) << "File: " << file << " is not exist.";
return Status(kMEInvalidInput, "File: " + file + " does not exist.");
}
*realpath_str = real_path_mem;
return kSuccess;
}
static Buffer ReadFile(const std::string &file) {
Buffer buffer;
if (file.empty()) {
MS_LOG(ERROR) << "Pointer file is nullptr";
return buffer;
}
std::string real_path;
auto status = RealPath(file, &real_path);
if (status != kSuccess) {
MS_LOG(ERROR) << status.GetErrDescription();
return buffer;
}
std::string real_path(real_path_mem);
std::ifstream ifs(real_path);
if (!ifs.good()) {
MS_LOG(ERROR) << "File: " << real_path << " is not exist";
MS_LOG(ERROR) << "File: " << real_path << " does not exist";
return buffer;
}
if (!ifs.is_open()) {
MS_LOG(ERROR) << "File: " << real_path << "open failed";
MS_LOG(ERROR) << "File: " << real_path << " open failed";
return buffer;
}
@ -139,7 +149,13 @@ Status Serialization::Load(const std::vector<char> &file, ModelType model_type,
return Status(kMEInvalidInput, err_msg.str());
}
std::string file_path = CharToString(file);
std::string file_path;
auto status = RealPath(CharToString(file), &file_path);
if (status != kSuccess) {
MS_LOG(ERROR) << status.GetErrDescription();
return status;
}
if (model_type == kMindIR) {
FuncGraphPtr anf_graph;
if (dec_key.len > dec_key.max_key_len) {
@ -193,7 +209,17 @@ Status Serialization::Load(const std::vector<std::vector<char>> &files, ModelTyp
return ret;
}
std::vector<std::string> files_path = VectorCharToString(files);
std::vector<std::string> files_path;
for (const auto &file : files) {
std::string file_path;
auto status = RealPath(CharToString(file), &file_path);
if (status != kSuccess) {
MS_LOG(ERROR) << status.GetErrDescription();
return status;
}
files_path.emplace_back(std::move(file_path));
}
if (model_type == kMindIR) {
if (dec_key.len > dec_key.max_key_len) {
err_msg << "The key length exceeds maximum length: " << dec_key.max_key_len;

View File

@ -0,0 +1,115 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "common/common_test.h"
#include "include/api/serialization.h"
namespace mindspore {
class TestCxxApiSerialization : public UT::Common {
public:
TestCxxApiSerialization() = default;
};
TEST_F(TestCxxApiSerialization, test_load_no_encrpty_mindir_SUCCESS) {
Graph graph;
ASSERT_TRUE(Serialization::Load("./data/mindir/add_no_encrpty.mindir", ModelType::kMindIR, &graph) == kSuccess);
}
TEST_F(TestCxxApiSerialization, test_load_output_args_nullptr_FAILED) {
auto status = Serialization::Load("./data/mindir/add_no_encrpty.mindir", ModelType::kMindIR, nullptr);
ASSERT_TRUE(status != kSuccess);
auto err_mst = status.GetErrDescription();
ASSERT_TRUE(err_mst.find("null") != std::string::npos);
}
TEST_F(TestCxxApiSerialization, test_load_file_not_exist_FAILED) {
Graph graph;
auto status = Serialization::Load("./data/mindir/file_not_exist.mindir", ModelType::kMindIR, &graph);
ASSERT_TRUE(status != kSuccess);
auto err_mst = status.GetErrDescription();
ASSERT_TRUE(err_mst.find("exist") != std::string::npos);
}
TEST_F(TestCxxApiSerialization, test_load_encrpty_mindir_SUCCESS) {
Graph graph;
std::string key_str = "0123456789ABCDEF";
Key key;
memcpy(key.key, key_str.c_str(), key_str.size());
key.len = key_str.size();
ASSERT_TRUE(Serialization::Load("./data/mindir/add_encrpty_key_0123456789ABCDEF.mindir", ModelType::kMindIR, &graph,
key, "AES-GCM") == kSuccess);
}
TEST_F(TestCxxApiSerialization, test_load_encrpty_mindir_without_key_FAILED) {
Graph graph;
auto status =
Serialization::Load("./data/mindir/add_encrpty_key_0123456789ABCDEF.mindir", ModelType::kMindIR, &graph);
ASSERT_TRUE(status != kSuccess);
auto err_mst = status.GetErrDescription();
ASSERT_TRUE(err_mst.find("be encrypted") != std::string::npos);
}
TEST_F(TestCxxApiSerialization, test_load_encrpty_mindir_with_wrong_key_FAILED) {
Graph graph;
std::string key_str = "WRONG_KEY";
Key key;
memcpy(key.key, key_str.c_str(), key_str.size());
key.len = key_str.size();
auto status = Serialization::Load("./data/mindir/add_encrpty_key_0123456789ABCDEF.mindir", ModelType::kMindIR, &graph,
key, "AES-GCM");
ASSERT_TRUE(status != kSuccess);
}
TEST_F(TestCxxApiSerialization, test_load_no_encrpty_mindir_with_wrong_key_FAILED) {
Graph graph;
std::string key_str = "WRONG_KEY";
Key key;
memcpy(key.key, key_str.c_str(), key_str.size());
key.len = key_str.size();
auto status = Serialization::Load("./data/mindir/add_no_encrpty.mindir", ModelType::kMindIR, &graph, key, "AES-GCM");
ASSERT_TRUE(status != kSuccess);
}
TEST_F(TestCxxApiSerialization, test_load_no_encrpty_mindir_x1_SUCCESS) {
std::vector<Graph> graphs;
ASSERT_TRUE(Serialization::Load(std::vector<std::string>(1, "./data/mindir/add_no_encrpty.mindir"),
ModelType::kMindIR, &graphs) == kSuccess);
}
TEST_F(TestCxxApiSerialization, test_load_no_encrpty_mindir_x2_SUCCESS) {
std::vector<Graph> graphs;
ASSERT_TRUE(Serialization::Load(std::vector<std::string>(2, "./data/mindir/add_no_encrpty.mindir"),
ModelType::kMindIR, &graphs) == kSuccess);
}
TEST_F(TestCxxApiSerialization, test_load_file_not_exist_x2_FAILED) {
std::vector<Graph> graphs;
auto status = Serialization::Load(std::vector<std::string>(2, "./data/mindir/file_not_exist.mindir"),
ModelType::kMindIR, &graphs);
ASSERT_TRUE(status != kSuccess);
auto err_mst = status.GetErrDescription();
ASSERT_TRUE(err_mst.find("exist") != std::string::npos);
}
TEST_F(TestCxxApiSerialization, test_load_encrpty_mindir_without_key_x2_FAILED) {
std::vector<Graph> graphs;
auto status = Serialization::Load(
std::vector<std::string>(2, "./data/mindir/add_encrpty_key_0123456789ABCDEF.mindir"), ModelType::kMindIR, &graphs);
ASSERT_TRUE(status != kSuccess);
auto err_mst = status.GetErrDescription();
ASSERT_TRUE(err_mst.find("be encrypted") != std::string::npos);
}
} // namespace mindspore

Binary file not shown.