!18583 [lite]improve converter register test ut
Merge pull request !18583 from 徐安越/master_core
This commit is contained in:
commit
2c3d45e34c
|
@ -14,13 +14,12 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
#include "common/common_test.h"
|
||||
#include "include/registry/model_parser_registry.h"
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "ut/tools/converter/registry/model_parser_test.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
|
||||
using mindspore::lite::ModelRegistrar;
|
||||
using mindspore::lite::converter::Flags;
|
||||
namespace mindspore {
|
||||
class ModelParserRegistryTest : public mindspore::CommonTest {
|
||||
|
@ -28,26 +27,27 @@ class ModelParserRegistryTest : public mindspore::CommonTest {
|
|||
ModelParserRegistryTest() = default;
|
||||
};
|
||||
|
||||
class ModelParserTest : public lite::ModelParser {
|
||||
public:
|
||||
ModelParserTest() = default;
|
||||
};
|
||||
|
||||
lite::ModelParser *TestModelParserCreator() {
|
||||
auto *parser = new (std::nothrow) ModelParserTest();
|
||||
if (parser == nullptr) {
|
||||
MS_LOG(ERROR) << "new model parser failed";
|
||||
return nullptr;
|
||||
}
|
||||
return parser;
|
||||
}
|
||||
REG_MODEL_PARSER(TEST, TestModelParserCreator);
|
||||
|
||||
TEST_F(ModelParserRegistryTest, TestRegistry) {
|
||||
auto node_parser_reg = NodeParserTestRegistry::GetInstance();
|
||||
auto add_parser = node_parser_reg->GetNodeParser("add");
|
||||
ASSERT_NE(add_parser, nullptr);
|
||||
auto proposal_parser = node_parser_reg->GetNodeParser("proposal");
|
||||
ASSERT_NE(proposal_parser, nullptr);
|
||||
auto model_parser = lite::ModelParserRegistry::GetInstance()->GetModelParser("TEST");
|
||||
ASSERT_NE(model_parser, nullptr);
|
||||
Flags flags;
|
||||
auto func_graph = model_parser->Parse(flags);
|
||||
ASSERT_EQ(func_graph, nullptr);
|
||||
ASSERT_NE(func_graph, nullptr);
|
||||
auto node_list = func_graph->GetOrderedCnodes();
|
||||
ASSERT_EQ(node_list.size(), 3);
|
||||
auto iter = node_list.begin();
|
||||
bool is_add = opt::CheckPrimitiveType(*iter, prim::kPrimAddFusion);
|
||||
ASSERT_EQ(is_add, true);
|
||||
++iter;
|
||||
is_add = opt::CheckPrimitiveType(*iter, prim::kPrimAddFusion);
|
||||
ASSERT_EQ(is_add, true);
|
||||
++iter;
|
||||
bool is_return = opt::CheckPrimitiveType(*iter, prim::kPrimReturn);
|
||||
ASSERT_EQ(is_return, true);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,173 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ut/tools/converter/registry/model_parser_test.h"
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "include/errorcode.h"
|
||||
#include "include/registry/model_parser_registry.h"
|
||||
|
||||
using mindspore::lite::ModelRegistrar;
|
||||
namespace mindspore {
|
||||
FuncGraphPtr ModelParserTest::Parse(const lite::converter::Flags &flag) {
|
||||
// construct funcgraph
|
||||
res_graph_ = std::make_shared<FuncGraph>();
|
||||
auto ret = InitOriginModelStructure();
|
||||
if (ret != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "obtain origin model structure failed.";
|
||||
return nullptr;
|
||||
}
|
||||
ret = BuildGraphInputs();
|
||||
if (ret != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "build graph inputs failed.";
|
||||
return nullptr;
|
||||
}
|
||||
ret = BuildGraphNodes();
|
||||
if (ret != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "build graph nodes failed.";
|
||||
return nullptr;
|
||||
}
|
||||
ret = BuildGraphOutputs();
|
||||
if (ret != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "build graph outputs failed.";
|
||||
return nullptr;
|
||||
}
|
||||
return res_graph_;
|
||||
}
|
||||
|
||||
int ModelParserTest::InitOriginModelStructure() {
|
||||
model_structure_ = {"add_0", "add_1"};
|
||||
model_layers_info_ = {{"input", {"graph_input0"}},
|
||||
{"add_0", {"graph_input0", "const_0"}},
|
||||
{"add_1", {"add_0", "const_1"}},
|
||||
{"output", {"add_1"}}};
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
int ModelParserTest::BuildGraphInputs() {
|
||||
if (model_layers_info_.find("input") == model_layers_info_.end()) {
|
||||
MS_LOG(ERROR) << "model is invalid";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto inputs = model_layers_info_["input"];
|
||||
for (auto &input : inputs) {
|
||||
auto parameter = res_graph_->add_parameter();
|
||||
if (parameter == nullptr) {
|
||||
MS_LOG(ERROR) << "build parameter node failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
ShapeVector shape{10, 10};
|
||||
auto tensor_info = std::make_shared<tensor::Tensor>(TypeId::kNumberTypeFloat32, shape);
|
||||
if (tensor_info == nullptr) {
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
parameter->set_name(input);
|
||||
parameter->set_abstract(tensor_info->ToAbstract());
|
||||
nodes_.insert(std::make_pair(input, parameter));
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
int ModelParserTest::BuildGraphNodes() {
|
||||
if (model_structure_.empty()) {
|
||||
MS_LOG(ERROR) << "model is invalid.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
for (auto &node_name : model_structure_) {
|
||||
if (model_layers_info_.find(node_name) == model_layers_info_.end()) {
|
||||
MS_LOG(ERROR) << "model is invalid.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto node_inputs = model_layers_info_[node_name];
|
||||
if (node_inputs.empty()) {
|
||||
MS_LOG(ERROR) << "model is invalid.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto type = node_name.substr(0, node_name.find_last_of("_"));
|
||||
auto node_parser = NodeParserTestRegistry::GetInstance()->GetNodeParser(type);
|
||||
if (node_parser == nullptr) {
|
||||
MS_LOG(ERROR) << "cannot find current op parser.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto primc = node_parser->Parse();
|
||||
if (primc == nullptr) {
|
||||
MS_LOG(ERROR) << "node parser failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
std::vector<AnfNodePtr> anf_inputs;
|
||||
for (auto &input : node_inputs) {
|
||||
if (nodes_.find(input) != nodes_.end()) {
|
||||
anf_inputs.push_back(nodes_[input]);
|
||||
} else {
|
||||
auto parameter = res_graph_->add_parameter();
|
||||
if (parameter == nullptr) {
|
||||
MS_LOG(ERROR) << "build parameter node failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
ShapeVector shape{10, 10};
|
||||
auto tensor_info = std::make_shared<tensor::Tensor>(TypeId::kNumberTypeFloat32, shape);
|
||||
auto size = tensor_info->Size();
|
||||
memset_s(tensor_info->data_c(), size, 0, size);
|
||||
parameter->set_abstract(tensor_info->ToAbstract());
|
||||
parameter->set_default_param(tensor_info);
|
||||
parameter->set_name(input);
|
||||
anf_inputs.push_back(parameter);
|
||||
nodes_.insert(std::make_pair(input, parameter));
|
||||
}
|
||||
}
|
||||
auto cnode = res_graph_->NewCNode(std::shared_ptr<ops::PrimitiveC>(primc), anf_inputs);
|
||||
cnode->set_fullname_with_scope(node_name);
|
||||
auto tensor_info = std::make_shared<tensor::Tensor>(TypeId::kNumberTypeFloat32, ShapeVector{});
|
||||
cnode->set_abstract(tensor_info->ToAbstract());
|
||||
nodes_.insert(std::make_pair(node_name, cnode));
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
int ModelParserTest::BuildGraphOutputs() {
|
||||
if (model_layers_info_.find("output") == model_layers_info_.end()) {
|
||||
MS_LOG(ERROR) << "model is invalid.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto outputs = model_layers_info_["output"];
|
||||
if (outputs.empty()) {
|
||||
MS_LOG(ERROR) << "odel is invalid.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (outputs.size() > 1) {
|
||||
// need generate a MakeTuple to package outputs.
|
||||
} else {
|
||||
if (nodes_.find(outputs[0]) == nodes_.end()) {
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto return_prim = std::make_shared<Primitive>("Return");
|
||||
auto return_cnode = res_graph_->NewCNode(return_prim, {nodes_[outputs[0]]});
|
||||
return_cnode->set_fullname_with_scope("Return");
|
||||
res_graph_->set_return(return_cnode);
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
lite::ModelParser *TestModelParserCreator() {
|
||||
auto *model_parser = new (std::nothrow) ModelParserTest();
|
||||
if (model_parser == nullptr) {
|
||||
MS_LOG(ERROR) << "new model parser failed";
|
||||
return nullptr;
|
||||
}
|
||||
return model_parser;
|
||||
}
|
||||
REG_MODEL_PARSER(TEST, TestModelParserCreator);
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,45 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_TEST_UT_TOOLS_CONVERTER_REGISTRY_MODEL_PARSER_TEST_H
|
||||
#define LITE_TEST_UT_TOOLS_CONVERTER_REGISTRY_MODEL_PARSER_TEST_H
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "include/registry/model_parser_registry.h"
|
||||
#include "ut/tools/converter/registry/node_parser_test.h"
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
|
||||
namespace mindspore {
|
||||
class ModelParserTest : public lite::ModelParser {
|
||||
public:
|
||||
ModelParserTest() = default;
|
||||
FuncGraphPtr Parse(const lite::converter::Flags &flag) override;
|
||||
|
||||
private:
|
||||
int InitOriginModelStructure();
|
||||
int BuildGraphInputs();
|
||||
int BuildGraphNodes();
|
||||
int BuildGraphOutputs();
|
||||
std::map<std::string, AnfNodePtr> nodes_;
|
||||
std::map<std::string, std::vector<std::string>> model_layers_info_;
|
||||
std::vector<std::string> model_structure_;
|
||||
};
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_TEST_UT_TOOLS_CONVERTER_REGISTRY_MODEL_PARSER_TEST_H
|
|
@ -0,0 +1,88 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ut/tools/converter/registry/node_parser_test.h"
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "ops/fusion/add_fusion.h"
|
||||
#include "ops/split.h"
|
||||
#include "ops/concat.h"
|
||||
#include "ops/custom.h"
|
||||
|
||||
namespace mindspore {
|
||||
class AddNodeParserTest : public NodeParserTest {
|
||||
public:
|
||||
AddNodeParserTest() = default;
|
||||
~AddNodeParserTest() = default;
|
||||
ops::PrimitiveC *Parse() override {
|
||||
auto primc = std::make_unique<ops::AddFusion>();
|
||||
return primc.release();
|
||||
}
|
||||
};
|
||||
|
||||
class SplitNodeParserTest : public NodeParserTest {
|
||||
public:
|
||||
SplitNodeParserTest() = default;
|
||||
~SplitNodeParserTest() = default;
|
||||
ops::PrimitiveC *Parse() override {
|
||||
auto primc = std::make_unique<ops::Split>();
|
||||
primc->set_axis(0);
|
||||
primc->set_output_num(2);
|
||||
return primc.release();
|
||||
}
|
||||
};
|
||||
|
||||
class ConcatNodeParserTest : public NodeParserTest {
|
||||
public:
|
||||
ConcatNodeParserTest() = default;
|
||||
~ConcatNodeParserTest() = default;
|
||||
ops::PrimitiveC *Parse() override {
|
||||
auto primc = std::make_unique<ops::Concat>();
|
||||
primc->set_axis(0);
|
||||
return primc.release();
|
||||
}
|
||||
};
|
||||
|
||||
// hypothesize custom op called proposal has these attrs : ["image_height", "image_width"].
|
||||
class CustomProposalNodeParserTest : public NodeParserTest {
|
||||
public:
|
||||
CustomProposalNodeParserTest() = default;
|
||||
~CustomProposalNodeParserTest() = default;
|
||||
ops::PrimitiveC *Parse() override {
|
||||
auto primc = std::make_unique<ops::Custom>();
|
||||
primc->set_type("Proposal");
|
||||
std::map<std::string, std::vector<uint8_t>> custom_attrs;
|
||||
std::string height = std::to_string(100);
|
||||
std::vector<uint8_t> height_attr(height.begin(), height.end());
|
||||
custom_attrs["image_height"] = height_attr;
|
||||
std::string width = std::to_string(200);
|
||||
std::vector<uint8_t> width_attr(width.begin(), width.end());
|
||||
custom_attrs["image_width"] = width_attr;
|
||||
primc->set_attr(custom_attrs);
|
||||
return primc.release();
|
||||
}
|
||||
};
|
||||
|
||||
constexpr auto kAdd = "add";
|
||||
constexpr auto kSplit = "split";
|
||||
constexpr auto kConcat = "concat";
|
||||
constexpr auto kProposal = "proposal";
|
||||
REGISTER_NODE_PARSER_TEST(kAdd, std::make_shared<AddNodeParserTest>())
|
||||
REGISTER_NODE_PARSER_TEST(kSplit, std::make_shared<SplitNodeParserTest>())
|
||||
REGISTER_NODE_PARSER_TEST(kConcat, std::make_shared<ConcatNodeParserTest>())
|
||||
REGISTER_NODE_PARSER_TEST(kProposal, std::make_shared<CustomProposalNodeParserTest>())
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,70 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_TEST_UT_TOOLS_CONVERTER_REGISTRY_NODE_PARSER_TEST_H
|
||||
#define LITE_TEST_UT_TOOLS_CONVERTER_REGISTRY_NODE_PARSER_TEST_H
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
class NodeParserTest {
|
||||
public:
|
||||
NodeParserTest() = default;
|
||||
|
||||
virtual ~NodeParserTest() {}
|
||||
|
||||
virtual ops::PrimitiveC *Parse() { return nullptr; }
|
||||
};
|
||||
using NodeParserTestPtr = std::shared_ptr<NodeParserTest>;
|
||||
|
||||
class NodeParserTestRegistry {
|
||||
public:
|
||||
static NodeParserTestRegistry *GetInstance() {
|
||||
static NodeParserTestRegistry instance;
|
||||
return &instance;
|
||||
}
|
||||
|
||||
NodeParserTestPtr GetNodeParser(const std::string &name) {
|
||||
if (parsers_.find(name) == parsers_.end()) {
|
||||
MS_LOG(ERROR) << "cannot find node parser.";
|
||||
return nullptr;
|
||||
}
|
||||
return parsers_[name];
|
||||
}
|
||||
|
||||
void RegNodeParser(const std::string &name, const NodeParserTestPtr node_parser) { parsers_[name] = node_parser; }
|
||||
|
||||
private:
|
||||
NodeParserTestRegistry() = default;
|
||||
virtual ~NodeParserTestRegistry() = default;
|
||||
std::unordered_map<std::string, NodeParserTestPtr> parsers_;
|
||||
};
|
||||
|
||||
class RegisterNodeParserTest {
|
||||
public:
|
||||
RegisterNodeParserTest(const std::string &name, NodeParserTestPtr node_parser) {
|
||||
NodeParserTestRegistry::GetInstance()->RegNodeParser(name, node_parser);
|
||||
}
|
||||
};
|
||||
|
||||
#define REGISTER_NODE_PARSER_TEST(name, node_parser) \
|
||||
static RegisterNodeParserTest g_##name##_node_parser(name, node_parser);
|
||||
} // namespace mindspore
|
||||
#endif // LITE_TEST_UT_TOOLS_CONVERTER_REGISTRY_NODE_PARSER_TEST_H
|
|
@ -14,40 +14,218 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "common/common_test.h"
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "include/registry/model_parser_registry.h"
|
||||
#include "include/registry/pass_registry.h"
|
||||
#include "ops/fusion/add_fusion.h"
|
||||
#include "ops/addn.h"
|
||||
#include "ops/custom.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
using mindspore::lite::converter::Flags;
|
||||
namespace mindspore {
|
||||
class PassRegistryTest : public mindspore::CommonTest {
|
||||
public:
|
||||
PassRegistryTest() = default;
|
||||
void SetUp() override {
|
||||
auto model_parser = lite::ModelParserRegistry::GetInstance()->GetModelParser("TEST");
|
||||
if (model_parser == nullptr) {
|
||||
return;
|
||||
}
|
||||
Flags flags;
|
||||
func_graph_ = model_parser->Parse(flags);
|
||||
}
|
||||
FuncGraphPtr func_graph_ = nullptr;
|
||||
};
|
||||
|
||||
namespace opt {
|
||||
// fuse add and add to addn.
|
||||
class Test1Fusion : public Pass {
|
||||
public:
|
||||
Test1Fusion() : Pass("test1_fusion") {}
|
||||
bool CanFusion(const CNodePtr &cnode) {
|
||||
if (cnode == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (!opt::CheckPrimitiveType(cnode, prim::kPrimAddFusion)) {
|
||||
return false;
|
||||
}
|
||||
auto primc = GetValueNode<std::shared_ptr<ops::AddFusion>>(cnode->input(0));
|
||||
if (primc == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (primc->GetAttr(ops::kActivationType) != nullptr && primc->get_activation_type() != mindspore::NO_ACTIVATION) {
|
||||
return false;
|
||||
}
|
||||
size_t input_cnode_num = 0;
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
auto input = cnode->input(i);
|
||||
if (!utils::isa<CNodePtr>(input)) {
|
||||
continue;
|
||||
}
|
||||
if (!opt::CheckPrimitiveType(input, prim::kPrimAddFusion)) {
|
||||
return false;
|
||||
}
|
||||
auto input_cnode = input->cast<CNodePtr>();
|
||||
auto add_primc = GetValueNode<std::shared_ptr<ops::AddFusion>>(input_cnode->input(0));
|
||||
if (add_primc == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (add_primc->GetAttr(ops::kActivationType) != nullptr &&
|
||||
add_primc->get_activation_type() != mindspore::NO_ACTIVATION) {
|
||||
return false;
|
||||
}
|
||||
++input_cnode_num;
|
||||
continue;
|
||||
}
|
||||
return input_cnode_num > 0;
|
||||
}
|
||||
|
||||
bool Run(const FuncGraphPtr &func_graph) override {
|
||||
if (func_graph == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto manager = func_graph->manager();
|
||||
if (manager == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNode>(node)) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (!CanFusion(cnode)) {
|
||||
continue;
|
||||
}
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
auto input_node = cnode->input(i);
|
||||
if (!utils::isa<CNode>(input_node)) {
|
||||
inputs.push_back(input_node);
|
||||
continue;
|
||||
}
|
||||
auto input_cnode = input_node->cast<CNodePtr>();
|
||||
for (size_t j = 1; j < input_cnode->size(); ++j) {
|
||||
inputs.push_back(input_cnode->input(j));
|
||||
}
|
||||
}
|
||||
auto primc = std::make_shared<ops::AddN>();
|
||||
auto new_cnode = func_graph->NewCNode(primc, inputs);
|
||||
new_cnode->set_fullname_with_scope(cnode->fullname_with_scope());
|
||||
new_cnode->set_abstract(cnode->abstract()->Clone());
|
||||
manager->Replace(node, new_cnode);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
// convert addn to custom op
|
||||
class Test2Fusion : public Pass {
|
||||
public:
|
||||
Test2Fusion() : Pass("test2_fusion") {}
|
||||
AnfNodePtr CreateCustomOp(const FuncGraphPtr func_graph, const CNodePtr &cnode) {
|
||||
if (cnode == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto primc = std::make_shared<ops::Custom>();
|
||||
if (primc == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
primc->set_type("Custom_AddN");
|
||||
std::map<std::string, std::vector<uint8_t>> custom_attrs;
|
||||
std::string input_num = std::to_string(3);
|
||||
std::vector<uint8_t> input_num_attr(input_num.begin(), input_num.end());
|
||||
custom_attrs["input_num"] = input_num_attr;
|
||||
std::string op_kind = "custom op";
|
||||
std::vector<uint8_t> op_kind_attr(op_kind.begin(), op_kind.end());
|
||||
custom_attrs["op_kind"] = op_kind_attr;
|
||||
primc->set_attr(custom_attrs);
|
||||
auto inputs = cnode->inputs();
|
||||
inputs.erase(inputs.begin());
|
||||
auto custom_cnode = func_graph->NewCNode(primc, inputs);
|
||||
custom_cnode->set_fullname_with_scope(cnode->fullname_with_scope());
|
||||
custom_cnode->set_abstract(cnode->abstract()->Clone());
|
||||
return custom_cnode;
|
||||
}
|
||||
|
||||
bool Run(const FuncGraphPtr &func_graph) override {
|
||||
if (func_graph == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto manager = func_graph->manager();
|
||||
if (manager == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNode>(node)) {
|
||||
continue;
|
||||
}
|
||||
if (!opt::CheckPrimitiveType(node, prim::kPrimAddN)) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto custome_cnode = CreateCustomOp(func_graph, cnode);
|
||||
if (custome_cnode == nullptr) {
|
||||
return false;
|
||||
}
|
||||
manager->Replace(node, custome_cnode);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
class TestFusion : public Pass {
|
||||
public:
|
||||
TestFusion() : Pass("test_fusion") {}
|
||||
bool Run(const FuncGraphPtr &func_graph) override { return true; }
|
||||
bool Run(const FuncGraphPtr &func_graph) override {
|
||||
if (func_graph == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto manager = Manage(func_graph, true);
|
||||
if (manager == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto test1_fusion = std::make_shared<Test1Fusion>();
|
||||
if (!test1_fusion->Run(func_graph)) {
|
||||
return false;
|
||||
}
|
||||
auto test2_fusion = std::make_shared<Test2Fusion>();
|
||||
if (!test2_fusion->Run(func_graph)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
REG_PASS(POSITION_BEGIN, TestFusion)
|
||||
REG_PASS(POSITION_END, TestFusion)
|
||||
} // namespace opt
|
||||
|
||||
TEST_F(PassRegistryTest, TestRegistry) {
|
||||
auto passes = opt::PassRegistry::GetInstance()->GetPasses();
|
||||
ASSERT_EQ(passes.size(), 2);
|
||||
ASSERT_EQ(passes.size(), 1);
|
||||
auto begin_pass = passes[opt::POSITION_BEGIN];
|
||||
ASSERT_NE(begin_pass, nullptr);
|
||||
auto begin_pass_test = std::dynamic_pointer_cast<opt::TestFusion>(begin_pass);
|
||||
ASSERT_NE(begin_pass_test, nullptr);
|
||||
auto res = begin_pass_test->Run(nullptr);
|
||||
ASSERT_EQ(res, true);
|
||||
auto end_pass = passes[opt::POSITION_END];
|
||||
ASSERT_NE(end_pass, nullptr);
|
||||
auto end_pass_test = std::dynamic_pointer_cast<opt::TestFusion>(end_pass);
|
||||
ASSERT_NE(end_pass_test, nullptr);
|
||||
res = end_pass_test->Run(nullptr);
|
||||
ASSERT_NE(func_graph_, nullptr);
|
||||
auto res = begin_pass_test->Run(func_graph_);
|
||||
ASSERT_EQ(res, true);
|
||||
auto cnode_list = func_graph_->GetOrderedCnodes();
|
||||
ASSERT_EQ(cnode_list.size(), 2);
|
||||
bool is_custom = opt::CheckPrimitiveType(cnode_list.front(), prim::kPrimCustom);
|
||||
ASSERT_EQ(is_custom, true);
|
||||
auto custome_prim = GetValueNode<std::shared_ptr<ops::Custom>>(cnode_list.front()->input(0));
|
||||
ASSERT_NE(custome_prim, nullptr);
|
||||
auto type = custome_prim->get_type();
|
||||
ASSERT_EQ(type, std::string("Custom_AddN"));
|
||||
bool is_return = opt::CheckPrimitiveType(cnode_list.back(), prim::kPrimReturn);
|
||||
ASSERT_EQ(is_return, true);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue