!18583 [lite]improve converter register test ut

Merge pull request !18583 from 徐安越/master_core
i-robot 2021-06-21 14:20:02 +08:00 committed by Gitee
commit 2c3d45e34c
6 changed files with 584 additions and 30 deletions

View File

@@ -14,13 +14,12 @@
* limitations under the License.
*/
#include <functional>
#include <vector>
#include "common/common_test.h"
#include "include/registry/model_parser_registry.h"
#include "tools/converter/model_parser.h"
#include "ut/tools/converter/registry/model_parser_test.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/converter/converter_flags.h"
using mindspore::lite::ModelRegistrar;
using mindspore::lite::converter::Flags;
namespace mindspore {
class ModelParserRegistryTest : public mindspore::CommonTest {
@@ -28,26 +27,27 @@ class ModelParserRegistryTest : public mindspore::CommonTest {
ModelParserRegistryTest() = default;
};
class ModelParserTest : public lite::ModelParser {
public:
ModelParserTest() = default;
};
lite::ModelParser *TestModelParserCreator() {
auto *parser = new (std::nothrow) ModelParserTest();
if (parser == nullptr) {
MS_LOG(ERROR) << "new model parser failed";
return nullptr;
}
return parser;
}
REG_MODEL_PARSER(TEST, TestModelParserCreator);
TEST_F(ModelParserRegistryTest, TestRegistry) {
auto node_parser_reg = NodeParserTestRegistry::GetInstance();
auto add_parser = node_parser_reg->GetNodeParser("add");
ASSERT_NE(add_parser, nullptr);
auto proposal_parser = node_parser_reg->GetNodeParser("proposal");
ASSERT_NE(proposal_parser, nullptr);
auto model_parser = lite::ModelParserRegistry::GetInstance()->GetModelParser("TEST");
ASSERT_NE(model_parser, nullptr);
Flags flags;
auto func_graph = model_parser->Parse(flags);
ASSERT_EQ(func_graph, nullptr);
ASSERT_NE(func_graph, nullptr);
auto node_list = func_graph->GetOrderedCnodes();
ASSERT_EQ(node_list.size(), 3);
auto iter = node_list.begin();
bool is_add = opt::CheckPrimitiveType(*iter, prim::kPrimAddFusion);
ASSERT_EQ(is_add, true);
++iter;
is_add = opt::CheckPrimitiveType(*iter, prim::kPrimAddFusion);
ASSERT_EQ(is_add, true);
++iter;
bool is_return = opt::CheckPrimitiveType(*iter, prim::kPrimReturn);
ASSERT_EQ(is_return, true);
}
} // namespace mindspore

View File

@@ -0,0 +1,173 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/registry/model_parser_test.h"
#include <map>
#include <vector>
#include "include/errorcode.h"
#include "include/registry/model_parser_registry.h"
using mindspore::lite::ModelRegistrar;
namespace mindspore {
FuncGraphPtr ModelParserTest::Parse(const lite::converter::Flags &flag) {
// construct funcgraph
res_graph_ = std::make_shared<FuncGraph>();
auto ret = InitOriginModelStructure();
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "obtain origin model structure failed.";
return nullptr;
}
ret = BuildGraphInputs();
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "build graph inputs failed.";
return nullptr;
}
ret = BuildGraphNodes();
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "build graph nodes failed.";
return nullptr;
}
ret = BuildGraphOutputs();
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "build graph outputs failed.";
return nullptr;
}
return res_graph_;
}
int ModelParserTest::InitOriginModelStructure() {
model_structure_ = {"add_0", "add_1"};
model_layers_info_ = {{"input", {"graph_input0"}},
{"add_0", {"graph_input0", "const_0"}},
{"add_1", {"add_0", "const_1"}},
{"output", {"add_1"}}};
return lite::RET_OK;
}
int ModelParserTest::BuildGraphInputs() {
if (model_layers_info_.find("input") == model_layers_info_.end()) {
MS_LOG(ERROR) << "model is invalid";
return lite::RET_ERROR;
}
auto inputs = model_layers_info_["input"];
for (auto &input : inputs) {
auto parameter = res_graph_->add_parameter();
if (parameter == nullptr) {
MS_LOG(ERROR) << "build parameter node failed.";
return lite::RET_ERROR;
}
ShapeVector shape{10, 10};
auto tensor_info = std::make_shared<tensor::Tensor>(TypeId::kNumberTypeFloat32, shape);
if (tensor_info == nullptr) {
return lite::RET_ERROR;
}
parameter->set_name(input);
parameter->set_abstract(tensor_info->ToAbstract());
nodes_.insert(std::make_pair(input, parameter));
}
return lite::RET_OK;
}
int ModelParserTest::BuildGraphNodes() {
if (model_structure_.empty()) {
MS_LOG(ERROR) << "model is invalid.";
return lite::RET_ERROR;
}
for (auto &node_name : model_structure_) {
if (model_layers_info_.find(node_name) == model_layers_info_.end()) {
MS_LOG(ERROR) << "model is invalid.";
return lite::RET_ERROR;
}
auto node_inputs = model_layers_info_[node_name];
if (node_inputs.empty()) {
MS_LOG(ERROR) << "model is invalid.";
return lite::RET_ERROR;
}
auto type = node_name.substr(0, node_name.find_last_of("_"));
auto node_parser = NodeParserTestRegistry::GetInstance()->GetNodeParser(type);
if (node_parser == nullptr) {
MS_LOG(ERROR) << "cannot find current op parser.";
return lite::RET_ERROR;
}
auto primc = node_parser->Parse();
if (primc == nullptr) {
MS_LOG(ERROR) << "node parser failed.";
return lite::RET_ERROR;
}
std::vector<AnfNodePtr> anf_inputs;
for (auto &input : node_inputs) {
if (nodes_.find(input) != nodes_.end()) {
anf_inputs.push_back(nodes_[input]);
} else {
auto parameter = res_graph_->add_parameter();
if (parameter == nullptr) {
MS_LOG(ERROR) << "build parameter node failed.";
return lite::RET_ERROR;
}
ShapeVector shape{10, 10};
auto tensor_info = std::make_shared<tensor::Tensor>(TypeId::kNumberTypeFloat32, shape);
auto size = tensor_info->Size();
memset_s(tensor_info->data_c(), size, 0, size);
parameter->set_abstract(tensor_info->ToAbstract());
parameter->set_default_param(tensor_info);
parameter->set_name(input);
anf_inputs.push_back(parameter);
nodes_.insert(std::make_pair(input, parameter));
}
}
auto cnode = res_graph_->NewCNode(std::shared_ptr<ops::PrimitiveC>(primc), anf_inputs);
cnode->set_fullname_with_scope(node_name);
auto tensor_info = std::make_shared<tensor::Tensor>(TypeId::kNumberTypeFloat32, ShapeVector{});
cnode->set_abstract(tensor_info->ToAbstract());
nodes_.insert(std::make_pair(node_name, cnode));
}
return lite::RET_OK;
}
int ModelParserTest::BuildGraphOutputs() {
if (model_layers_info_.find("output") == model_layers_info_.end()) {
MS_LOG(ERROR) << "model is invalid.";
return lite::RET_ERROR;
}
auto outputs = model_layers_info_["output"];
if (outputs.empty()) {
MS_LOG(ERROR) << "odel is invalid.";
return lite::RET_ERROR;
}
if (outputs.size() > 1) {
// need to generate a MakeTuple to package multiple outputs.
} else {
if (nodes_.find(outputs[0]) == nodes_.end()) {
return lite::RET_ERROR;
}
auto return_prim = std::make_shared<Primitive>("Return");
auto return_cnode = res_graph_->NewCNode(return_prim, {nodes_[outputs[0]]});
return_cnode->set_fullname_with_scope("Return");
res_graph_->set_return(return_cnode);
}
return lite::RET_OK;
}
lite::ModelParser *TestModelParserCreator() {
auto *model_parser = new (std::nothrow) ModelParserTest();
if (model_parser == nullptr) {
MS_LOG(ERROR) << "new model parser failed";
return nullptr;
}
return model_parser;
}
REG_MODEL_PARSER(TEST, TestModelParserCreator);
} // namespace mindspore
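
For context on the toy model description used above, the parser drives everything from a topological node list plus a name-to-inputs map, and recovers each op type by stripping the trailing "_&lt;index&gt;" from the node name (the find_last_of call in BuildGraphNodes). The following standalone sketch mirrors the same data as InitOriginModelStructure; it has no mindspore dependency and its variable names are illustrative only.

#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

int main() {
  // Same toy description as InitOriginModelStructure: execution order plus inputs per node.
  std::vector<std::string> model_structure = {"add_0", "add_1"};
  std::map<std::string, std::vector<std::string>> model_layers_info = {
      {"input", {"graph_input0"}},
      {"add_0", {"graph_input0", "const_0"}},
      {"add_1", {"add_0", "const_1"}},
      {"output", {"add_1"}}};

  // Names already materialized as graph inputs; anything referenced later but never
  // produced by a node becomes a constant parameter (cf. BuildGraphNodes).
  std::set<std::string> known(model_layers_info["input"].begin(), model_layers_info["input"].end());

  for (const auto &node_name : model_structure) {
    // Op type = node name without the trailing "_<index>", e.g. "add_0" -> "add".
    auto type = node_name.substr(0, node_name.find_last_of('_'));
    std::cout << node_name << " (type: " << type << ") inputs:";
    for (const auto &input : model_layers_info[node_name]) {
      bool is_const = known.find(input) == known.end();
      std::cout << " " << input << (is_const ? "[const]" : "");
    }
    std::cout << std::endl;
    known.insert(node_name);  // this node's output can now feed later nodes
  }
  std::cout << "graph output: " << model_layers_info["output"].front() << std::endl;
  return 0;
}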

View File

@@ -0,0 +1,45 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_TEST_UT_TOOLS_CONVERTER_REGISTRY_MODEL_PARSER_TEST_H
#define LITE_TEST_UT_TOOLS_CONVERTER_REGISTRY_MODEL_PARSER_TEST_H
#include <map>
#include <string>
#include <vector>
#include "include/registry/model_parser_registry.h"
#include "ut/tools/converter/registry/node_parser_test.h"
#include "tools/converter/model_parser.h"
#include "tools/converter/converter_flags.h"
namespace mindspore {
class ModelParserTest : public lite::ModelParser {
public:
ModelParserTest() = default;
FuncGraphPtr Parse(const lite::converter::Flags &flag) override;
private:
int InitOriginModelStructure();
int BuildGraphInputs();
int BuildGraphNodes();
int BuildGraphOutputs();
std::map<std::string, AnfNodePtr> nodes_;
std::map<std::string, std::vector<std::string>> model_layers_info_;
std::vector<std::string> model_structure_;
};
} // namespace mindspore
#endif // LITE_TEST_UT_TOOLS_CONVERTER_REGISTRY_MODEL_PARSER_TEST_H

View File

@@ -0,0 +1,88 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/tools/converter/registry/node_parser_test.h"
#include <map>
#include <string>
#include <vector>
#include "ops/fusion/add_fusion.h"
#include "ops/split.h"
#include "ops/concat.h"
#include "ops/custom.h"
namespace mindspore {
class AddNodeParserTest : public NodeParserTest {
public:
AddNodeParserTest() = default;
~AddNodeParserTest() = default;
ops::PrimitiveC *Parse() override {
auto primc = std::make_unique<ops::AddFusion>();
return primc.release();
}
};
class SplitNodeParserTest : public NodeParserTest {
public:
SplitNodeParserTest() = default;
~SplitNodeParserTest() = default;
ops::PrimitiveC *Parse() override {
auto primc = std::make_unique<ops::Split>();
primc->set_axis(0);
primc->set_output_num(2);
return primc.release();
}
};
class ConcatNodeParserTest : public NodeParserTest {
public:
ConcatNodeParserTest() = default;
~ConcatNodeParserTest() = default;
ops::PrimitiveC *Parse() override {
auto primc = std::make_unique<ops::Concat>();
primc->set_axis(0);
return primc.release();
}
};
// assume a custom op named "proposal" with the attributes ["image_height", "image_width"].
class CustomProposalNodeParserTest : public NodeParserTest {
public:
CustomProposalNodeParserTest() = default;
~CustomProposalNodeParserTest() = default;
ops::PrimitiveC *Parse() override {
auto primc = std::make_unique<ops::Custom>();
primc->set_type("Proposal");
std::map<std::string, std::vector<uint8_t>> custom_attrs;
std::string height = std::to_string(100);
std::vector<uint8_t> height_attr(height.begin(), height.end());
custom_attrs["image_height"] = height_attr;
std::string width = std::to_string(200);
std::vector<uint8_t> width_attr(width.begin(), width.end());
custom_attrs["image_width"] = width_attr;
primc->set_attr(custom_attrs);
return primc.release();
}
};
constexpr auto kAdd = "add";
constexpr auto kSplit = "split";
constexpr auto kConcat = "concat";
constexpr auto kProposal = "proposal";
REGISTER_NODE_PARSER_TEST(kAdd, std::make_shared<AddNodeParserTest>())
REGISTER_NODE_PARSER_TEST(kSplit, std::make_shared<SplitNodeParserTest>())
REGISTER_NODE_PARSER_TEST(kConcat, std::make_shared<ConcatNodeParserTest>())
REGISTER_NODE_PARSER_TEST(kProposal, std::make_shared<CustomProposalNodeParserTest>())
} // namespace mindspore
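
The proposal parser above stores numeric attributes by serializing them to text and copying the bytes into std::vector<uint8_t>, which is the value type that ops::Custom's attr map takes. A standalone sketch of that round trip follows; the helper names PackIntAttr and UnpackIntAttr are hypothetical and there is no mindspore dependency.

#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Pack an integer attribute the same way CustomProposalNodeParserTest does:
// number -> decimal string -> raw bytes.
std::vector<uint8_t> PackIntAttr(int value) {
  std::string text = std::to_string(value);
  return std::vector<uint8_t>(text.begin(), text.end());
}

// Reverse step a consumer of the custom op would perform.
int UnpackIntAttr(const std::vector<uint8_t> &bytes) {
  return std::stoi(std::string(bytes.begin(), bytes.end()));
}

int main() {
  std::map<std::string, std::vector<uint8_t>> custom_attrs;
  custom_attrs["image_height"] = PackIntAttr(100);
  custom_attrs["image_width"] = PackIntAttr(200);

  for (const auto &attr : custom_attrs) {
    std::cout << attr.first << " = " << UnpackIntAttr(attr.second) << std::endl;
  }
  return 0;
}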

View File

@@ -0,0 +1,70 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_TEST_UT_TOOLS_CONVERTER_REGISTRY_NODE_PARSER_TEST_H
#define LITE_TEST_UT_TOOLS_CONVERTER_REGISTRY_NODE_PARSER_TEST_H
#include <memory>
#include <string>
#include <unordered_map>
#include "ops/primitive_c.h"
#include "src/common/log_adapter.h"
namespace mindspore {
class NodeParserTest {
public:
NodeParserTest() = default;
virtual ~NodeParserTest() {}
virtual ops::PrimitiveC *Parse() { return nullptr; }
};
using NodeParserTestPtr = std::shared_ptr<NodeParserTest>;
class NodeParserTestRegistry {
public:
static NodeParserTestRegistry *GetInstance() {
static NodeParserTestRegistry instance;
return &instance;
}
NodeParserTestPtr GetNodeParser(const std::string &name) {
if (parsers_.find(name) == parsers_.end()) {
MS_LOG(ERROR) << "cannot find node parser.";
return nullptr;
}
return parsers_[name];
}
void RegNodeParser(const std::string &name, const NodeParserTestPtr node_parser) { parsers_[name] = node_parser; }
private:
NodeParserTestRegistry() = default;
virtual ~NodeParserTestRegistry() = default;
std::unordered_map<std::string, NodeParserTestPtr> parsers_;
};
class RegisterNodeParserTest {
public:
RegisterNodeParserTest(const std::string &name, NodeParserTestPtr node_parser) {
NodeParserTestRegistry::GetInstance()->RegNodeParser(name, node_parser);
}
};
#define REGISTER_NODE_PARSER_TEST(name, node_parser) \
static RegisterNodeParserTest g_##name##_node_parser(name, node_parser);
} // namespace mindspore
#endif // LITE_TEST_UT_TOOLS_CONVERTER_REGISTRY_NODE_PARSER_TEST_H
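
For readers unfamiliar with this self-registration idiom, here is a minimal standalone sketch of the same pattern using hypothetical toy types (ToyParser, ToyRegistry, REGISTER_TOY_PARSER) rather than the real classes above: a file-scope object whose constructor inserts into a singleton map before main runs, and a token-pasting macro that derives a unique global name from the registration key's identifier, just as REGISTER_NODE_PARSER_TEST(kAdd, ...) defines a global named g_kAdd_node_parser.

#include <iostream>
#include <map>
#include <memory>
#include <string>

// Toy stand-in for NodeParserTest.
struct ToyParser {
  virtual ~ToyParser() = default;
  virtual std::string Name() const { return "base"; }
};
using ToyParserPtr = std::shared_ptr<ToyParser>;

// Singleton registry keyed by op name, mirroring NodeParserTestRegistry.
class ToyRegistry {
 public:
  static ToyRegistry *GetInstance() {
    static ToyRegistry instance;
    return &instance;
  }
  void Reg(const std::string &name, ToyParserPtr parser) { parsers_[name] = parser; }
  ToyParserPtr Get(const std::string &name) {
    auto iter = parsers_.find(name);
    return iter == parsers_.end() ? nullptr : iter->second;
  }

 private:
  std::map<std::string, ToyParserPtr> parsers_;
};

// Helper whose constructor performs the registration as a side effect.
struct ToyRegistrar {
  ToyRegistrar(const std::string &name, ToyParserPtr parser) {
    ToyRegistry::GetInstance()->Reg(name, parser);
  }
};

// Token pasting builds a unique global from the key's identifier, like
// g_##name##_node_parser in REGISTER_NODE_PARSER_TEST.
#define REGISTER_TOY_PARSER(name, parser) static ToyRegistrar g_##name##_toy_parser(name, parser);

struct AddToyParser : public ToyParser {
  std::string Name() const override { return "add"; }
};
constexpr auto kAdd = "add";
REGISTER_TOY_PARSER(kAdd, std::make_shared<AddToyParser>())

int main() {
  auto parser = ToyRegistry::GetInstance()->Get("add");
  std::cout << (parser != nullptr ? parser->Name() : std::string("not found")) << std::endl;
  return 0;
}

One consequence of the token pasting is that the uniqueness of the generated global comes from the key's identifier (kAdd), not from the string it points to, so each registration needs its own named constant, as the kAdd/kSplit/kConcat/kProposal constants above provide.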

View File

@@ -14,40 +14,218 @@
* limitations under the License.
*/
#include <map>
#include <string>
#include <vector>
#include "common/common_test.h"
#include "backend/optimizer/common/pass.h"
#include "include/registry/model_parser_registry.h"
#include "include/registry/pass_registry.h"
#include "ops/fusion/add_fusion.h"
#include "ops/addn.h"
#include "ops/custom.h"
#include "tools/converter/converter_flags.h"
#include "tools/converter/model_parser.h"
#include "tools/optimizer/common/gllo_utils.h"
using mindspore::lite::converter::Flags;
namespace mindspore {
class PassRegistryTest : public mindspore::CommonTest {
public:
PassRegistryTest() = default;
void SetUp() override {
auto model_parser = lite::ModelParserRegistry::GetInstance()->GetModelParser("TEST");
if (model_parser == nullptr) {
return;
}
Flags flags;
func_graph_ = model_parser->Parse(flags);
}
FuncGraphPtr func_graph_ = nullptr;
};
namespace opt {
// fuse two adjacent add ops into an addn.
class Test1Fusion : public Pass {
public:
Test1Fusion() : Pass("test1_fusion") {}
bool CanFusion(const CNodePtr &cnode) {
if (cnode == nullptr) {
return false;
}
if (!opt::CheckPrimitiveType(cnode, prim::kPrimAddFusion)) {
return false;
}
auto primc = GetValueNode<std::shared_ptr<ops::AddFusion>>(cnode->input(0));
if (primc == nullptr) {
return false;
}
if (primc->GetAttr(ops::kActivationType) != nullptr && primc->get_activation_type() != mindspore::NO_ACTIVATION) {
return false;
}
size_t input_cnode_num = 0;
for (size_t i = 1; i < cnode->size(); ++i) {
auto input = cnode->input(i);
if (!utils::isa<CNodePtr>(input)) {
continue;
}
if (!opt::CheckPrimitiveType(input, prim::kPrimAddFusion)) {
return false;
}
auto input_cnode = input->cast<CNodePtr>();
auto add_primc = GetValueNode<std::shared_ptr<ops::AddFusion>>(input_cnode->input(0));
if (add_primc == nullptr) {
return false;
}
if (add_primc->GetAttr(ops::kActivationType) != nullptr &&
add_primc->get_activation_type() != mindspore::NO_ACTIVATION) {
return false;
}
++input_cnode_num;
continue;
}
return input_cnode_num > 0;
}
bool Run(const FuncGraphPtr &func_graph) override {
if (func_graph == nullptr) {
return false;
}
auto manager = func_graph->manager();
if (manager == nullptr) {
return false;
}
auto node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNode>(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (!CanFusion(cnode)) {
continue;
}
std::vector<AnfNodePtr> inputs;
for (size_t i = 1; i < cnode->size(); ++i) {
auto input_node = cnode->input(i);
if (!utils::isa<CNode>(input_node)) {
inputs.push_back(input_node);
continue;
}
auto input_cnode = input_node->cast<CNodePtr>();
for (size_t j = 1; j < input_cnode->size(); ++j) {
inputs.push_back(input_cnode->input(j));
}
}
auto primc = std::make_shared<ops::AddN>();
auto new_cnode = func_graph->NewCNode(primc, inputs);
new_cnode->set_fullname_with_scope(cnode->fullname_with_scope());
new_cnode->set_abstract(cnode->abstract()->Clone());
manager->Replace(node, new_cnode);
}
return true;
}
};
// convert addn to custom op
class Test2Fusion : public Pass {
public:
Test2Fusion() : Pass("test2_fusion") {}
AnfNodePtr CreateCustomOp(const FuncGraphPtr func_graph, const CNodePtr &cnode) {
if (cnode == nullptr) {
return nullptr;
}
auto primc = std::make_shared<ops::Custom>();
if (primc == nullptr) {
return nullptr;
}
primc->set_type("Custom_AddN");
std::map<std::string, std::vector<uint8_t>> custom_attrs;
std::string input_num = std::to_string(3);
std::vector<uint8_t> input_num_attr(input_num.begin(), input_num.end());
custom_attrs["input_num"] = input_num_attr;
std::string op_kind = "custom op";
std::vector<uint8_t> op_kind_attr(op_kind.begin(), op_kind.end());
custom_attrs["op_kind"] = op_kind_attr;
primc->set_attr(custom_attrs);
auto inputs = cnode->inputs();
inputs.erase(inputs.begin());
auto custom_cnode = func_graph->NewCNode(primc, inputs);
custom_cnode->set_fullname_with_scope(cnode->fullname_with_scope());
custom_cnode->set_abstract(cnode->abstract()->Clone());
return custom_cnode;
}
bool Run(const FuncGraphPtr &func_graph) override {
if (func_graph == nullptr) {
return false;
}
auto manager = func_graph->manager();
if (manager == nullptr) {
return false;
}
auto node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNode>(node)) {
continue;
}
if (!opt::CheckPrimitiveType(node, prim::kPrimAddN)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
auto custom_cnode = CreateCustomOp(func_graph, cnode);
if (custom_cnode == nullptr) {
return false;
}
manager->Replace(node, custom_cnode);
}
return true;
}
};
class TestFusion : public Pass {
public:
TestFusion() : Pass("test_fusion") {}
bool Run(const FuncGraphPtr &func_graph) override { return true; }
bool Run(const FuncGraphPtr &func_graph) override {
if (func_graph == nullptr) {
return false;
}
auto manager = Manage(func_graph, true);
if (manager == nullptr) {
return false;
}
auto test1_fusion = std::make_shared<Test1Fusion>();
if (!test1_fusion->Run(func_graph)) {
return false;
}
auto test2_fusion = std::make_shared<Test2Fusion>();
if (!test2_fusion->Run(func_graph)) {
return false;
}
return true;
}
};
REG_PASS(POSITION_BEGIN, TestFusion)
REG_PASS(POSITION_END, TestFusion)
} // namespace opt
TEST_F(PassRegistryTest, TestRegistry) {
auto passes = opt::PassRegistry::GetInstance()->GetPasses();
ASSERT_EQ(passes.size(), 2);
ASSERT_EQ(passes.size(), 1);
auto begin_pass = passes[opt::POSITION_BEGIN];
ASSERT_NE(begin_pass, nullptr);
auto begin_pass_test = std::dynamic_pointer_cast<opt::TestFusion>(begin_pass);
ASSERT_NE(begin_pass_test, nullptr);
auto res = begin_pass_test->Run(nullptr);
ASSERT_EQ(res, true);
auto end_pass = passes[opt::POSITION_END];
ASSERT_NE(end_pass, nullptr);
auto end_pass_test = std::dynamic_pointer_cast<opt::TestFusion>(end_pass);
ASSERT_NE(end_pass_test, nullptr);
res = end_pass_test->Run(nullptr);
ASSERT_NE(func_graph_, nullptr);
auto res = begin_pass_test->Run(func_graph_);
ASSERT_EQ(res, true);
auto cnode_list = func_graph_->GetOrderedCnodes();
ASSERT_EQ(cnode_list.size(), 2);
bool is_custom = opt::CheckPrimitiveType(cnode_list.front(), prim::kPrimCustom);
ASSERT_EQ(is_custom, true);
auto custom_prim = GetValueNode<std::shared_ptr<ops::Custom>>(cnode_list.front()->input(0));
ASSERT_NE(custom_prim, nullptr);
auto type = custom_prim->get_type();
ASSERT_EQ(type, std::string("Custom_AddN"));
bool is_return = opt::CheckPrimitiveType(cnode_list.back(), prim::kPrimReturn);
ASSERT_EQ(is_return, true);
}
} // namespace mindspore
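
For intuition about what Test1Fusion and Test2Fusion do to the parsed graph (add_0 and add_1 collapse into one three-input node, which is then rewritten as a "Custom_AddN" op, matching the assertions in the test), here is a small standalone sketch over a hypothetical ToyNode structure rather than real CNodes; it only mimics the shape of the rewrite, not the FuncGraph manager mechanics.

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Hypothetical toy node: op type plus names of its inputs.
struct ToyNode {
  std::string type;
  std::vector<std::string> inputs;
};

int main() {
  // Parsed graph from the TEST model parser: add_1(add_0(graph_input0, const_0), const_1).
  std::map<std::string, ToyNode> graph = {
      {"add_0", {"Add", {"graph_input0", "const_0"}}},
      {"add_1", {"Add", {"add_0", "const_1"}}}};
  std::string output = "add_1";

  // Step 1 (Test1Fusion-like): fold an Add whose input is itself an Add into one AddN,
  // splicing the inner Add's inputs in place of that input.
  ToyNode &out_node = graph[output];
  std::vector<std::string> fused_inputs;
  for (const auto &in : out_node.inputs) {
    auto it = graph.find(in);
    if (it != graph.end() && it->second.type == "Add") {
      fused_inputs.insert(fused_inputs.end(), it->second.inputs.begin(), it->second.inputs.end());
      graph.erase(it);  // the inner Add is consumed by the fusion
    } else {
      fused_inputs.push_back(in);
    }
  }
  out_node.type = "AddN";
  out_node.inputs = fused_inputs;

  // Step 2 (Test2Fusion-like): rewrite AddN as a custom op, keeping the same inputs.
  if (out_node.type == "AddN") {
    out_node.type = "Custom_AddN";
  }

  std::cout << output << " -> " << out_node.type << "(";
  for (size_t i = 0; i < out_node.inputs.size(); ++i) {
    std::cout << (i ? ", " : "") << out_node.inputs[i];
  }
  std::cout << ")" << std::endl;  // prints: add_1 -> Custom_AddN(graph_input0, const_0, const_1)
  return 0;
}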