forked from mindspore-Ecosystem/mindspore
add subgraph util for func_graph
This commit is contained in:
parent
e6e544dbc4
commit
0e62e061a2
|
@ -0,0 +1,555 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/common/func_graph_subgraph.h"
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <queue>
|
||||
#include <utility>
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "tools/common/node_util.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "ops/fusion/partial_fusion.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
SubGraph::SubGraph(FuncGraphPtr belong_anf, std::string graph_name, const std::set<CNodePtr> &head_nodes)
|
||||
: belong_anf_(std::move(belong_anf)), name_(std::move(graph_name)) {
|
||||
InitSubGraphNode(head_nodes);
|
||||
InitSubGraphInNode();
|
||||
InitSubGraphOutNode();
|
||||
}
|
||||
|
||||
void SubGraph::Reset(const std::set<CNodePtr> &nodes, const std::set<CNodePtr> &head_nodes) {
|
||||
this->nodes_ = nodes;
|
||||
InitSubGraphNode(head_nodes);
|
||||
InitSubGraphInNode();
|
||||
InitSubGraphOutNode();
|
||||
}
|
||||
|
||||
std::set<CNodePtr> SubGraph::GetNodes() const { return this->nodes_; }
|
||||
|
||||
std::set<CNodePtr> SubGraph::GetInCNodes() const { return this->in_nodes_; }
|
||||
|
||||
std::set<CNodePtr> SubGraph::GetInputCNodes() const {
|
||||
std::set<CNodePtr> inputs;
|
||||
for (const auto &in_node : in_nodes_) {
|
||||
if (in_node == nullptr) {
|
||||
continue;
|
||||
}
|
||||
auto input_cnodes = GetInputCNode(in_node);
|
||||
inputs.insert(input_cnodes.begin(), input_cnodes.end());
|
||||
}
|
||||
return inputs;
|
||||
}
|
||||
|
||||
std::set<CNodePtr> SubGraph::GetOutCNodes() const { return this->out_nodes_; }
|
||||
|
||||
std::set<CNodePtr> SubGraph::FindCommonOutputs(const SubGraphPtr &subgraph) const {
|
||||
if (subgraph == nullptr) {
|
||||
return {};
|
||||
}
|
||||
std::set<CNodePtr> outputs_this = this->GetOutputCNodes();
|
||||
if (this == subgraph.get()) {
|
||||
return outputs_this;
|
||||
}
|
||||
std::set<CNodePtr> outputs_other = subgraph->GetOutputCNodes();
|
||||
std::set<CNodePtr> common_outputs;
|
||||
for (const auto &output1 : outputs_this) {
|
||||
if (output1 == nullptr) {
|
||||
continue;
|
||||
}
|
||||
auto iter = outputs_other.find(output1);
|
||||
if (iter == outputs_other.end()) {
|
||||
continue;
|
||||
}
|
||||
if (GetInputCNode(output1).size() == 2) {
|
||||
common_outputs.insert(output1);
|
||||
}
|
||||
}
|
||||
return common_outputs;
|
||||
}
|
||||
|
||||
bool SubGraph::IfDependOnSameNode(const SubGraphPtr &subgraph) const {
|
||||
if (subgraph == nullptr || this == subgraph.get()) {
|
||||
return false;
|
||||
}
|
||||
std::set<CNodePtr> inputs_this = this->GetInputCNodes();
|
||||
std::set<CNodePtr> inputs_other = subgraph->GetInputCNodes();
|
||||
return std::any_of(inputs_this.begin(), inputs_this.end(), [&inputs_other](const CNodePtr &input_this) {
|
||||
if (input_this == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return (inputs_other.count(input_this) > 0);
|
||||
});
|
||||
}
|
||||
|
||||
std::set<CNodePtr> SubGraph::GetOutputCNodes() const {
|
||||
MS_ASSERT(belong_anf_ != nullptr);
|
||||
MS_ASSERT(belong_anf_->manager() != nullptr);
|
||||
auto node_users = belong_anf_->manager()->node_users();
|
||||
std::set<CNodePtr> outputs;
|
||||
for (const auto &out_node : out_nodes_) {
|
||||
if (out_node == nullptr) {
|
||||
continue;
|
||||
}
|
||||
auto iter = node_users.find(out_node);
|
||||
if (iter == node_users.end()) {
|
||||
continue;
|
||||
}
|
||||
auto post_node_pairs = iter->second;
|
||||
for (const auto &post_node_pair : post_node_pairs) {
|
||||
auto post_node = post_node_pair.first;
|
||||
if (post_node == nullptr || !utils::isa<CNodePtr>(post_node)) {
|
||||
continue;
|
||||
}
|
||||
outputs.insert(utils::cast<CNodePtr>(post_node));
|
||||
}
|
||||
}
|
||||
return outputs;
|
||||
}
|
||||
|
||||
void SubGraph::InitSubGraphNode(const std::set<CNodePtr> &head_nodes) {
|
||||
MS_ASSERT(belong_anf_ != nullptr);
|
||||
MS_ASSERT(belong_anf_->manager() != nullptr);
|
||||
auto node_users = belong_anf_->manager()->node_users();
|
||||
std::queue<CNodePtr> q;
|
||||
for (const auto &head_node : head_nodes) {
|
||||
if (head_node == nullptr) {
|
||||
continue;
|
||||
}
|
||||
q.push(head_node);
|
||||
}
|
||||
while (!q.empty()) {
|
||||
auto cur_node = q.front();
|
||||
MS_ASSERT(cur_node != nullptr);
|
||||
q.pop();
|
||||
this->nodes_.insert(cur_node);
|
||||
// check output-cnode of cur-node only depend on cur-node
|
||||
auto iter = node_users.find(cur_node);
|
||||
if (iter == node_users.end()) {
|
||||
continue;
|
||||
}
|
||||
auto post_node_pairs = iter->second;
|
||||
for (const auto &post_node_pair : post_node_pairs) {
|
||||
auto post_node = post_node_pair.first;
|
||||
if (post_node == nullptr || !utils::isa<CNodePtr>(post_node)) {
|
||||
continue;
|
||||
}
|
||||
auto post_cnode = utils::cast<CNodePtr>(post_node);
|
||||
// return-node should not be include into subgraph absolutely // ut
|
||||
if (opt::CheckPrimitiveType(post_cnode, prim::kPrimReturn)) {
|
||||
continue;
|
||||
}
|
||||
MS_ASSERT(post_cnode != nullptr);
|
||||
bool non_depend = true;
|
||||
// check all inputs of output-cnode
|
||||
for (const auto &input : post_cnode->inputs()) {
|
||||
if (input == nullptr) {
|
||||
continue;
|
||||
}
|
||||
// input cnode is not contained in subgraph
|
||||
if (utils::isa<CNodePtr>(input)) {
|
||||
auto input_cnode = utils::cast<CNodePtr>(input);
|
||||
if (this->nodes_.count(input_cnode) == 0) {
|
||||
non_depend = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// input parameter is a graph input
|
||||
if (utils::isa<ParameterPtr>(input)) {
|
||||
auto input_parameter = utils::cast<ParameterPtr>(input);
|
||||
if (!input_parameter->has_default()) {
|
||||
non_depend = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (non_depend) {
|
||||
q.push(post_cnode);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SubGraph::InitSubGraphInNode() {
|
||||
MS_ASSERT(belong_anf_ != nullptr);
|
||||
MS_ASSERT(belong_anf_->manager() != nullptr);
|
||||
auto node_users = belong_anf_->manager()->node_users();
|
||||
this->in_nodes_.clear();
|
||||
for (const auto &node : this->nodes_) {
|
||||
if (node == nullptr) {
|
||||
continue;
|
||||
}
|
||||
if (std::any_of(node->inputs().begin(), node->inputs().end(), [this, &node_users](const auto &input) {
|
||||
if (input == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (utils::isa<CNodePtr>(input)) {
|
||||
auto input_cnode = utils::cast<CNodePtr>(input);
|
||||
if (this->nodes_.count(input_cnode) == 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
// graph input or shared weight input // ut
|
||||
if (utils::isa<ParameterPtr>(input)) {
|
||||
auto input_parameter = utils::cast<ParameterPtr>(input);
|
||||
if (!input_parameter->has_default()) {
|
||||
return true;
|
||||
}
|
||||
auto output_pair_iter = node_users.find(input);
|
||||
if (output_pair_iter != node_users.end() && output_pair_iter->second.size() > 1) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
})) {
|
||||
in_nodes_.insert(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SubGraph::InitSubGraphOutNode() {
|
||||
MS_ASSERT(belong_anf_ != nullptr);
|
||||
MS_ASSERT(belong_anf_->manager() != nullptr);
|
||||
auto node_users = belong_anf_->manager()->node_users();
|
||||
this->out_nodes_.clear();
|
||||
for (const auto &node : this->nodes_) {
|
||||
if (node == nullptr) {
|
||||
continue;
|
||||
}
|
||||
auto node_users_iter = node_users.find(node);
|
||||
if (node_users_iter == node_users.end()) {
|
||||
continue;
|
||||
}
|
||||
auto node_output_pairs = node_users_iter->second;
|
||||
if (!std::any_of(node_output_pairs.begin(), node_output_pairs.end(),
|
||||
[this](const std::pair<AnfNodePtr, int> &output_pair) {
|
||||
auto output_node = output_pair.first;
|
||||
if (output_node == nullptr || !utils::isa<CNodePtr>(output_node)) {
|
||||
return false;
|
||||
}
|
||||
// graph output // ut
|
||||
if (opt::CheckPrimitiveType(output_node, prim::kPrimReturn)) {
|
||||
return true;
|
||||
}
|
||||
auto output_cnode = utils::cast<CNodePtr>(output_node);
|
||||
if (this->nodes_.count(output_cnode) == 0) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}))
|
||||
continue;
|
||||
out_nodes_.insert(node);
|
||||
}
|
||||
}
|
||||
|
||||
bool SubGraph::MergeSubGraph(const SubGraphPtr &subgraph) {
|
||||
if (subgraph == nullptr || this == subgraph.get()) {
|
||||
return false;
|
||||
}
|
||||
// if two subgraph has same output, and this output node has only two input cnode which exactly from two
|
||||
// subgraph, we merge two subgraph, and find more post node
|
||||
auto common_outputs = this->FindCommonOutputs(subgraph);
|
||||
if (!common_outputs.empty()) {
|
||||
auto new_nodes = this->GetNodes();
|
||||
auto new_nodes2 = subgraph->GetNodes();
|
||||
new_nodes.insert(new_nodes2.begin(), new_nodes2.end());
|
||||
new_nodes.insert(common_outputs.begin(), common_outputs.end());
|
||||
this->Reset(new_nodes, common_outputs);
|
||||
return true;
|
||||
}
|
||||
|
||||
if (this->IfDependOnSameNode(subgraph)) {
|
||||
auto new_nodes = this->GetNodes();
|
||||
auto new_nodes2 = subgraph->GetNodes();
|
||||
new_nodes.insert(new_nodes2.begin(), new_nodes2.end());
|
||||
this->Reset(new_nodes);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// iterate node from in_nodes of current subgraph up to input of belong_anf
|
||||
SubGraphPtr SubGraph::FindBeforeSubGraphInBelongAnf() const {
|
||||
MS_ASSERT(belong_anf_ == nullptr);
|
||||
// find before subgraph's nodes
|
||||
std::queue<CNodePtr> q;
|
||||
std::set<CNodePtr> before_nodes;
|
||||
for (const auto &node : this->GetInCNodes()) {
|
||||
for (const auto &in_cnode : lite::GetInputCNode(node)) {
|
||||
if (in_cnode == nullptr) {
|
||||
continue;
|
||||
}
|
||||
q.push(in_cnode);
|
||||
}
|
||||
}
|
||||
while (!q.empty()) {
|
||||
auto cur_cnode = q.front();
|
||||
MS_ASSERT(cur_cnode != nullptr);
|
||||
q.pop();
|
||||
before_nodes.insert(cur_cnode);
|
||||
for (const auto &in_cnode : lite::GetInputCNode(cur_cnode)) {
|
||||
q.push(in_cnode);
|
||||
}
|
||||
}
|
||||
// construct before subgraph
|
||||
auto before_subgraph = std::make_shared<SubGraph>(belong_anf_, this->name_ + "/before_subgraph");
|
||||
before_subgraph->Reset(before_nodes);
|
||||
return before_subgraph;
|
||||
}
|
||||
|
||||
// iterate node from output of belong_anf up to out_nodes of current subgraph and before subgraph
|
||||
SubGraphPtr SubGraph::FindAfterSubGraphInBelongAnf() const {
|
||||
MS_ASSERT(belong_anf_ == nullptr);
|
||||
// find before subgraph
|
||||
auto before_subgraph = this->FindBeforeSubGraphInBelongAnf();
|
||||
if (before_subgraph == nullptr) {
|
||||
MS_LOG(ERROR) << "Find before subgraph failed";
|
||||
return nullptr;
|
||||
}
|
||||
// find after subgraph's nodes
|
||||
std::queue<CNodePtr> q;
|
||||
std::set<CNodePtr> after_nodes;
|
||||
auto output_node = belong_anf_->output();
|
||||
if (!utils::isa<CNodePtr>(output_node)) {
|
||||
MS_LOG(ERROR) << "Output node of anf should be a cnode: " << output_node->fullname_with_scope();
|
||||
return nullptr;
|
||||
}
|
||||
q.push(utils::cast<CNodePtr>(output_node));
|
||||
auto subgraph_out_nodes = this->GetOutCNodes();
|
||||
auto before_out_nodes = before_subgraph->GetOutCNodes();
|
||||
while (!q.empty()) {
|
||||
auto cur_cnode = q.front();
|
||||
MS_ASSERT(cur_cnode != nullptr);
|
||||
q.pop();
|
||||
after_nodes.insert(cur_cnode);
|
||||
for (const auto &in_cnode : lite::GetInputCNode(cur_cnode)) {
|
||||
if (subgraph_out_nodes.count(in_cnode) == 0 && before_out_nodes.count(in_cnode) == 0) {
|
||||
q.push(in_cnode);
|
||||
}
|
||||
}
|
||||
}
|
||||
// construct before subgraph
|
||||
auto after_subgraph = std::make_shared<SubGraph>(belong_anf_, this->name_ + "/after_subgraph");
|
||||
after_subgraph->Reset(after_nodes);
|
||||
return after_subgraph;
|
||||
}
|
||||
|
||||
int SubGraph::CreatePartialInBelongAnf() {
|
||||
MS_ASSERT(this->belong_anf_ != nullptr);
|
||||
MS_ASSERT(this->belong_anf_->manager() != nullptr);
|
||||
// determine func_graph name
|
||||
std::string graph_name = this->name_;
|
||||
if (graph_name.empty()) {
|
||||
if (this->nodes_.empty()) {
|
||||
graph_name = "subgraph";
|
||||
} else {
|
||||
graph_name = (*(this->nodes_.begin()))->fullname_with_scope() + "/subgraph";
|
||||
}
|
||||
}
|
||||
// create func_graph of partial
|
||||
FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
|
||||
auto manager = belong_anf_->manager();
|
||||
manager->AddFuncGraph(func_graph);
|
||||
func_graph->set_attr("graph_name", MakeValue(graph_name));
|
||||
func_graph->set_manager(manager);
|
||||
// create cnode and parameter for func_graph of partial
|
||||
std::vector<AnfNodePtr> partial_inputs;
|
||||
std::map<AnfNodePtr, AnfNodePtr> partial_inputs_and_subgraph_input_map;
|
||||
CreateParameterForPartialSubGraph(func_graph, &partial_inputs, &partial_inputs_and_subgraph_input_map);
|
||||
CreateCNodeForPartialSubGraph(func_graph, partial_inputs_and_subgraph_input_map);
|
||||
// add return for func_graph of partial
|
||||
auto sub_graph_outputs = this->GetOutCNodes();
|
||||
MS_ASSERT(!sub_graph_outputs.empty());
|
||||
auto ret = SetFuncGraphOutput(func_graph, sub_graph_outputs);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(DEBUG) << "Set subgraph output failed";
|
||||
return ret;
|
||||
}
|
||||
// create partial cnode
|
||||
auto partial_prim = std::make_shared<mindspore::ops::PartialFusion>();
|
||||
auto graph_value_node = NewValueNode(func_graph);
|
||||
partial_inputs.insert(partial_inputs.begin(), graph_value_node);
|
||||
auto partial_cnode = belong_anf_->NewCNode(partial_prim, partial_inputs);
|
||||
partial_cnode->set_fullname_with_scope(graph_name + "/partial");
|
||||
for (size_t i = 0; i < partial_inputs.size(); ++i) {
|
||||
const auto &input = partial_inputs.at(i);
|
||||
manager->SetEdge(partial_cnode, static_cast<int>(i + 1), input);
|
||||
}
|
||||
// create call cnode
|
||||
std::vector<AnfNodePtr> call_node_inputs{partial_cnode};
|
||||
auto call_cnode = belong_anf_->NewCNode(call_node_inputs);
|
||||
call_cnode->set_fullname_with_scope(graph_name + "/call");
|
||||
// replace belong-graph's output
|
||||
auto return_node = belong_anf_->get_return();
|
||||
MS_ASSERT(return_node != nullptr && return_node->inputs().size() == 2);
|
||||
auto ori_output = return_node->inputs().at(1);
|
||||
MS_ASSERT(ori_output != nullptr);
|
||||
manager->Replace(ori_output, call_cnode);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int SubGraph::SetFuncGraphOutput(const FuncGraphPtr &graph, const std::set<CNodePtr> &outputs) {
|
||||
std::vector<AnfNodePtr> output_nodes;
|
||||
output_nodes.insert(output_nodes.end(), outputs.begin(), outputs.end());
|
||||
return lite::SetFuncGraphOutput(graph, output_nodes);
|
||||
}
|
||||
|
||||
void SubGraph::CreateParameterForPartialSubGraph(
|
||||
const FuncGraphPtr &sub_graph, std::vector<AnfNodePtr> *partial_inputs,
|
||||
std::map<AnfNodePtr, AnfNodePtr> *partial_inputs_and_subgraph_input_map) {
|
||||
MS_ASSERT(sub_graph != nullptr);
|
||||
MS_ASSERT(partial_inputs != nullptr && partial_inputs->empty());
|
||||
MS_ASSERT(partial_inputs_and_subgraph_input_map != nullptr && partial_inputs_and_subgraph_input_map->empty());
|
||||
|
||||
std::string graph_name = sub_graph->get_attr("graph_name")->ToString();
|
||||
for (const auto &in_cnode : this->GetInCNodes()) {
|
||||
if (in_cnode == nullptr) {
|
||||
continue;
|
||||
}
|
||||
for (size_t i = 1; i < in_cnode->inputs().size(); i++) {
|
||||
auto input = in_cnode->input(i);
|
||||
if (input == nullptr) {
|
||||
continue;
|
||||
}
|
||||
auto iter = partial_inputs_and_subgraph_input_map->find(input);
|
||||
if (iter != partial_inputs_and_subgraph_input_map->end()) {
|
||||
continue;
|
||||
}
|
||||
// create subgraph input parameter from cnode and record partial inputs
|
||||
if (utils::isa<CNodePtr>(input)) {
|
||||
auto input_cnode = utils::cast<CNodePtr>(input);
|
||||
if (this->GetNodes().count(input_cnode) > 0) {
|
||||
continue;
|
||||
}
|
||||
partial_inputs->emplace_back(input);
|
||||
auto new_parameter = sub_graph->add_parameter();
|
||||
new_parameter->set_name(graph_name + "_input_" + input->fullname_with_scope());
|
||||
new_parameter->set_abstract(input->abstract());
|
||||
(*partial_inputs_and_subgraph_input_map)[input] = new_parameter;
|
||||
}
|
||||
// create subgraph input parameter from parameter and record partial inputs
|
||||
// add parameter to func_graph
|
||||
auto node_users = this->belong_anf_->manager()->node_users();
|
||||
if (utils::isa<ParameterPtr>(input)) {
|
||||
auto parameter = utils::cast<ParameterPtr>(input);
|
||||
// graph input: create a parameter
|
||||
if (!parameter->has_default()) {
|
||||
auto new_parameter = sub_graph->add_parameter();
|
||||
new_parameter->set_name(graph_name + "_input_" + input->fullname_with_scope());
|
||||
new_parameter->set_abstract(input->abstract());
|
||||
(*partial_inputs_and_subgraph_input_map)[input] = new_parameter;
|
||||
partial_inputs->emplace_back(new_parameter);
|
||||
}
|
||||
// weight parameter, it depends
|
||||
auto output_pairs_iter = node_users.find(input);
|
||||
if (output_pairs_iter != node_users.end() &&
|
||||
output_pairs_iter->second.size() > 1) { // shared weight: create a parameter
|
||||
auto new_parameter = sub_graph->add_parameter();
|
||||
new_parameter->set_name(graph_name + "_input_" + input->fullname_with_scope());
|
||||
new_parameter->set_abstract(input->abstract());
|
||||
(*partial_inputs_and_subgraph_input_map)[input] = new_parameter;
|
||||
partial_inputs->emplace_back(new_parameter);
|
||||
} else { // not shared weight: move into subgraph
|
||||
sub_graph->AddNode(input);
|
||||
input->set_func_graph(sub_graph);
|
||||
this->belong_anf_->DropNode(input);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SubGraph::CreateCNodeForPartialSubGraph(
|
||||
const FuncGraphPtr &sub_graph, const std::map<AnfNodePtr, AnfNodePtr> &partial_inputs_and_subgraph_input_map) {
|
||||
MS_ASSERT(sub_graph != nullptr);
|
||||
// move cnode from belong_graph to subgraph
|
||||
for (auto &node : this->GetNodes()) {
|
||||
sub_graph->AddNode(node);
|
||||
node->set_func_graph(sub_graph);
|
||||
for (size_t i = 0; i < node->inputs().size(); i++) {
|
||||
if (node == nullptr || node->inputs().at(i)) {
|
||||
continue;
|
||||
}
|
||||
auto input = node->inputs().at(i);
|
||||
auto iter = partial_inputs_and_subgraph_input_map.find(input);
|
||||
if (iter == partial_inputs_and_subgraph_input_map.end()) {
|
||||
continue;
|
||||
}
|
||||
// use SetEdge not set_input, if not, node_user is not updated.
|
||||
this->belong_anf_->manager()->SetEdge(node, static_cast<int>(i), iter->second);
|
||||
}
|
||||
this->belong_anf_->DropNode(node);
|
||||
}
|
||||
}
|
||||
|
||||
int SubGraph::ApplySubGraph() {
|
||||
// check
|
||||
if (this->nodes_.empty()) {
|
||||
return lite::RET_NO_CHANGE;
|
||||
}
|
||||
if (belong_anf_ == nullptr || belong_anf_->manager() == nullptr) {
|
||||
MS_LOG(DEBUG) << "belong_anf_ or manager is nullptr";
|
||||
return lite::RET_NO_CHANGE;
|
||||
}
|
||||
for (const auto &node : this->nodes_) {
|
||||
if (node == nullptr) {
|
||||
continue;
|
||||
}
|
||||
if (node->func_graph() != belong_anf_) {
|
||||
MS_LOG(DEBUG) << "subgraph nodes belong to different func_graph";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
// create after partial // redirect input of after subgraph
|
||||
auto after_subgraph = this->FindAfterSubGraphInBelongAnf();
|
||||
if (after_subgraph == nullptr) {
|
||||
MS_LOG(DEBUG) << "Create after subgraph failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto ret = after_subgraph->CreatePartialInBelongAnf();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(DEBUG) << "Create after partial failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
// merge after partial into subgraph
|
||||
auto subgraph_nodes = this->nodes_;
|
||||
auto return_node = belong_anf_->get_return();
|
||||
MS_ASSERT(return_node != nullptr && return_node->inputs().size() == 2);
|
||||
auto call_node = return_node->inputs().at(1);
|
||||
MS_ASSERT(call_node != nullptr && utils::isa<CNodePtr>(call_node));
|
||||
auto call_cnode = utils::cast<CNodePtr>(call_node);
|
||||
MS_ASSERT(call_cnode != nullptr && call_cnode->inputs().size() == 1);
|
||||
auto after_partial_node = call_cnode->inputs().at(0);
|
||||
MS_ASSERT(after_partial_node != nullptr && utils::isa<CNodePtr>(after_partial));
|
||||
auto after_partial_cnode = utils::cast<CNodePtr>(after_partial_node);
|
||||
MS_ASSERT(after_partial_cnode != nullptr);
|
||||
subgraph_nodes.insert(after_partial_cnode);
|
||||
subgraph_nodes.insert(call_cnode);
|
||||
this->Reset(subgraph_nodes);
|
||||
// create subgraph partial // add partial to main subgraph
|
||||
ret = this->CreatePartialInBelongAnf();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(DEBUG) << "Create partial failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::lite
|
|
@ -0,0 +1,75 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_FUNC_GRAPH_SUBGRAPH_H
|
||||
#define MINDSPORE_LITE_TOOLS_FUNC_GRAPH_SUBGRAPH_H
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
class SubGraph;
|
||||
using SubGraphPtr = std::shared_ptr<SubGraph>;
|
||||
class SubGraph {
|
||||
public:
|
||||
explicit SubGraph(FuncGraphPtr belong_anf, std::string graph_name = "", const std::set<CNodePtr> &head_nodes = {});
|
||||
|
||||
void Reset(const std::set<CNodePtr> &nodes, const std::set<CNodePtr> &head_nodes = {});
|
||||
|
||||
bool MergeSubGraph(const SubGraphPtr &subgraph);
|
||||
|
||||
std::set<CNodePtr> GetNodes() const;
|
||||
std::set<CNodePtr> GetInCNodes() const;
|
||||
std::set<CNodePtr> GetOutCNodes() const;
|
||||
|
||||
int ApplySubGraph();
|
||||
|
||||
private:
|
||||
std::set<CNodePtr> GetInputCNodes() const;
|
||||
std::set<CNodePtr> GetOutputCNodes() const;
|
||||
// init subgraph methods
|
||||
void InitSubGraphNode(const std::set<CNodePtr> &head_nodes);
|
||||
void InitSubGraphInNode();
|
||||
void InitSubGraphOutNode();
|
||||
// merge subgraph methods
|
||||
std::set<CNodePtr> FindCommonOutputs(const SubGraphPtr &subgraph) const;
|
||||
bool IfDependOnSameNode(const SubGraphPtr &subgraph) const;
|
||||
// apply subgraph methods
|
||||
SubGraphPtr FindBeforeSubGraphInBelongAnf() const;
|
||||
SubGraphPtr FindAfterSubGraphInBelongAnf() const;
|
||||
void CreateParameterForPartialSubGraph(const FuncGraphPtr &sub_graph, std::vector<AnfNodePtr> *partial_inputs,
|
||||
std::map<AnfNodePtr, AnfNodePtr> *partial_inputs_and_subgraph_input_map);
|
||||
void CreateCNodeForPartialSubGraph(const FuncGraphPtr &sub_graph,
|
||||
const std::map<AnfNodePtr, AnfNodePtr> &partial_inputs_and_subgraph_input_map);
|
||||
int CreatePartialInBelongAnf();
|
||||
static int SetFuncGraphOutput(const FuncGraphPtr &graph, const std::set<CNodePtr> &outputs);
|
||||
|
||||
private:
|
||||
std::set<CNodePtr> nodes_;
|
||||
std::set<CNodePtr> in_nodes_;
|
||||
std::set<CNodePtr> out_nodes_;
|
||||
const FuncGraphPtr belong_anf_;
|
||||
const std::string name_;
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
#endif // MINDSPORE_LITE_TOOLS_FUNC_GRAPH_SUBGRAPH_H
|
|
@ -26,6 +26,7 @@
|
|||
#include "tools/common/node_util.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "tools/converter/ops/ops_def.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -33,6 +34,29 @@ namespace {
|
|||
enum QuantBitNum { QuantBitNum_INT8 = 8, QuantBitNum_INT16 = 16 };
|
||||
const int kZeroPointGap = 128;
|
||||
} // namespace
|
||||
int SetFuncGraphOutput(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &outputs) {
|
||||
if (graph == nullptr || outputs.empty()) {
|
||||
MS_LOG(DEBUG) << "Input graph is nullptr or outputs is empty";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
if (outputs.size() == 1) {
|
||||
graph->set_output(outputs.front(), false);
|
||||
return RET_OK;
|
||||
}
|
||||
auto make_tuple_prim_ptr = std::make_shared<lite::MakeTuple>();
|
||||
if (make_tuple_prim_ptr == nullptr) {
|
||||
MS_LOG(DEBUG) << "new MakeTuple failed";
|
||||
return lite::RET_NULL_PTR;
|
||||
}
|
||||
auto make_tuple_cnode = graph->NewCNode(make_tuple_prim_ptr, outputs);
|
||||
if (make_tuple_prim_ptr == nullptr) {
|
||||
MS_LOG(DEBUG) << "new cnode failed";
|
||||
return lite::RET_NULL_PTR;
|
||||
}
|
||||
make_tuple_cnode->set_fullname_with_scope("return tuple");
|
||||
graph->set_output(make_tuple_cnode, false);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
OpDefCopyer GetSimpleOpCopyer() {
|
||||
return [](CNodeT *inCNode) -> std::unique_ptr<CNodeT> {
|
||||
|
|
|
@ -46,6 +46,8 @@ using OpDefCopyer = std::function<std::unique_ptr<schema::CNodeT>(schema::CNodeT
|
|||
|
||||
OpDefCopyer GetSimpleOpCopyer();
|
||||
|
||||
int SetFuncGraphOutput(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &outputs);
|
||||
|
||||
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, int inputIndexIdx = -1);
|
||||
|
||||
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
|
||||
|
|
|
@ -28,147 +28,19 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
constexpr size_t kInitialSize = 1024;
|
||||
|
||||
static const std::vector<schema::PrimitiveType> nhwcOpList = {schema::PrimitiveType_Conv2DBackpropFilterFusion,
|
||||
schema::PrimitiveType_Conv2DBackpropInputFusion,
|
||||
schema::PrimitiveType_AvgPoolGrad,
|
||||
schema::PrimitiveType_MaxPoolGrad,
|
||||
schema::PrimitiveType_BiasAddGrad,
|
||||
schema::PrimitiveType_BatchNormGrad,
|
||||
schema::PrimitiveType_ApplyMomentum,
|
||||
schema::PrimitiveType_SGD,
|
||||
schema::PrimitiveType_Adam,
|
||||
schema::PrimitiveType_ResizeGrad,
|
||||
schema::PrimitiveType_AvgPoolFusion,
|
||||
schema::PrimitiveType_MaxPoolFusion,
|
||||
schema::PrimitiveType_Conv2DFusion,
|
||||
schema::PrimitiveType_Conv2dTransposeFusion,
|
||||
schema::PrimitiveType_LRN,
|
||||
schema::PrimitiveType_Resize,
|
||||
schema::PrimitiveType_BatchNorm,
|
||||
schema::PrimitiveType_FusedBatchNorm,
|
||||
schema::PrimitiveType_PReLUFusion,
|
||||
schema::PrimitiveType_BiasAdd,
|
||||
schema::PrimitiveType_SpaceToDepth,
|
||||
schema::PrimitiveType_DepthToSpace,
|
||||
schema::PrimitiveType_TopKFusion,
|
||||
schema::PrimitiveType_BatchToSpace,
|
||||
schema::PrimitiveType_SpaceToBatch,
|
||||
schema::PrimitiveType_SpaceToBatchND};
|
||||
|
||||
static const std::vector<schema::PrimitiveType> nchwOpList = {schema::PrimitiveType_InstanceNorm};
|
||||
|
||||
static const std::vector<schema::PrimitiveType> nhwcOpAllInputList = {
|
||||
schema::PrimitiveType_AvgPoolGrad, schema::PrimitiveType_MaxPoolGrad,
|
||||
schema::PrimitiveType_ActivationGrad, schema::PrimitiveType_Conv2DBackpropFilterFusion,
|
||||
schema::PrimitiveType_BatchNormGrad, schema::PrimitiveType_ResizeGrad};
|
||||
|
||||
// index {} mean all inputs need insert
|
||||
static std::unordered_map<schema::PrimitiveType, std::vector<int>> extNhwcInsertIndex = {
|
||||
{schema::PrimitiveType_BatchNormGrad, {0, 1}},
|
||||
{schema::PrimitiveType_Conv2DBackpropFilterFusion, {0, 1}},
|
||||
{schema::PrimitiveType_ApplyMomentum, {3}},
|
||||
{schema::PrimitiveType_SGD, {1}},
|
||||
{schema::PrimitiveType_Adam, {9}}};
|
||||
|
||||
static const std::vector<schema::PrimitiveType> fp32FullOpList = {
|
||||
schema::PrimitiveType_Concat, schema::PrimitiveType_AddFusion,
|
||||
schema::PrimitiveType_Floor}; // fp32 ops support C4 and nhwc in fp32
|
||||
|
||||
static const std::vector<schema::PrimitiveType> int8NeedNhwcOpList = {};
|
||||
|
||||
static const std::vector<schema::PrimitiveType> int8OpList = {schema::PrimitiveType_Conv2DFusion,
|
||||
schema::PrimitiveType_Conv2dTransposeFusion,
|
||||
schema::PrimitiveType_AddFusion,
|
||||
schema::PrimitiveType_Transpose,
|
||||
schema::PrimitiveType_AvgPoolFusion,
|
||||
schema::PrimitiveType_MaxPoolFusion,
|
||||
schema::PrimitiveType_Concat,
|
||||
schema::PrimitiveType_Softmax,
|
||||
schema::PrimitiveType_Reshape,
|
||||
schema::PrimitiveType_Activation,
|
||||
schema::PrimitiveType_Resize,
|
||||
schema::PrimitiveType_FullConnection,
|
||||
schema::PrimitiveType_ArgMaxFusion,
|
||||
schema::PrimitiveType_ArgMinFusion,
|
||||
schema::PrimitiveType_BatchNorm,
|
||||
schema::PrimitiveType_FusedBatchNorm,
|
||||
schema::PrimitiveType_BiasAdd,
|
||||
schema::PrimitiveType_DivFusion,
|
||||
schema::PrimitiveType_MulFusion,
|
||||
schema::PrimitiveType_SliceFusion,
|
||||
schema::PrimitiveType_Split,
|
||||
schema::PrimitiveType_Squeeze,
|
||||
schema::PrimitiveType_SubFusion,
|
||||
schema::PrimitiveType_StridedSlice,
|
||||
schema::PrimitiveType_TopKFusion,
|
||||
schema::PrimitiveType_Unsqueeze,
|
||||
schema::PrimitiveType_MatMul,
|
||||
schema::PrimitiveType_PadFusion,
|
||||
schema::PrimitiveType_ScaleFusion,
|
||||
schema::PrimitiveType_Cast,
|
||||
schema::PrimitiveType_Shape,
|
||||
schema::PrimitiveType_ExpandDims,
|
||||
schema::PrimitiveType_BatchToSpace,
|
||||
schema::PrimitiveType_BatchToSpaceND,
|
||||
schema::PrimitiveType_ReduceFusion,
|
||||
schema::PrimitiveType_Round,
|
||||
schema::PrimitiveType_Floor,
|
||||
schema::PrimitiveType_Ceil,
|
||||
schema::PrimitiveType_Abs,
|
||||
schema::PrimitiveType_Sin,
|
||||
schema::PrimitiveType_Cos,
|
||||
schema::PrimitiveType_Log,
|
||||
schema::PrimitiveType_Sqrt,
|
||||
schema::PrimitiveType_Rsqrt,
|
||||
schema::PrimitiveType_Square,
|
||||
schema::PrimitiveType_LogicalNot,
|
||||
schema::PrimitiveType_SpaceToBatch,
|
||||
schema::PrimitiveType_SpaceToBatchND,
|
||||
schema::PrimitiveType_DepthToSpace,
|
||||
schema::PrimitiveType_PowFusion,
|
||||
schema::PrimitiveType_GatherNd,
|
||||
schema::PrimitiveType_LeakyRelu,
|
||||
schema::PrimitiveType_Gather,
|
||||
schema::PrimitiveType_Equal,
|
||||
schema::PrimitiveType_NotEqual,
|
||||
schema::PrimitiveType_LessEqual,
|
||||
schema::PrimitiveType_Greater,
|
||||
schema::PrimitiveType_GreaterEqual,
|
||||
schema::PrimitiveType_Eltwise,
|
||||
schema::PrimitiveType_DetectionPostProcess,
|
||||
schema::PrimitiveType_Crop,
|
||||
schema::PrimitiveType_PriorBox,
|
||||
schema::PrimitiveType_QuantDTypeCast,
|
||||
schema::PrimitiveType_LayerNormFusion,
|
||||
schema::PrimitiveType_L2NormalizeFusion};
|
||||
|
||||
static const std::vector<schema::PrimitiveType> needInsertOpList = {
|
||||
schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat,
|
||||
schema::PrimitiveType_PowFusion, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_AddFusion,
|
||||
schema::PrimitiveType_AddN, schema::PrimitiveType_Split, schema::PrimitiveType_SliceFusion,
|
||||
schema::PrimitiveType_Crop, schema::PrimitiveType_MulFusion, schema::PrimitiveType_Maximum,
|
||||
schema::PrimitiveType_ActivationGrad};
|
||||
|
||||
static const std::unordered_map<int, int> nc2NhAxisMap = {{0, 0}, {1, -1}, {2, 1}, {3, 2}};
|
||||
|
||||
std::unordered_map<int, int> GetNc2NhAxisMap() { return nc2NhAxisMap; }
|
||||
|
||||
std::vector<schema::PrimitiveType> GetInsertOpList() { return needInsertOpList; }
|
||||
|
||||
std::vector<schema::PrimitiveType> Getfp32FullOpList() { return fp32FullOpList; }
|
||||
|
||||
std::vector<schema::PrimitiveType> GetNhwcOpList() { return nhwcOpList; }
|
||||
|
||||
std::vector<schema::PrimitiveType> GetNchwOpList() { return nchwOpList; }
|
||||
|
||||
std::unordered_map<schema::PrimitiveType, std::vector<int>> GetExtNhwcIndexes() { return extNhwcInsertIndex; }
|
||||
|
||||
std::vector<schema::PrimitiveType> GetNhwcAllInputOpList() { return nhwcOpAllInputList; }
|
||||
|
||||
std::vector<schema::PrimitiveType> GetUint8NhwcOpList() { return int8NeedNhwcOpList; }
|
||||
|
||||
std::vector<schema::PrimitiveType> GetInt8OpList() { return int8OpList; }
|
||||
std::vector<CNodePtr> GetInputCNode(const CNodePtr &cnode) {
|
||||
if (cnode == nullptr) {
|
||||
return {};
|
||||
}
|
||||
std::vector<CNodePtr> inputs;
|
||||
for (const auto &input : cnode->inputs()) {
|
||||
if (input == nullptr || !utils::isa<CNodePtr>(input)) {
|
||||
continue;
|
||||
}
|
||||
inputs.emplace_back(utils::cast<CNodePtr>(input));
|
||||
}
|
||||
return inputs;
|
||||
}
|
||||
|
||||
const schema::Primitive *ConvertToPrimitive(schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
if (primitive_t == nullptr || fbb == nullptr) {
|
||||
|
|
|
@ -31,6 +31,8 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
std::vector<CNodePtr> GetInputCNode(const CNodePtr &cnode);
|
||||
|
||||
template <typename T>
|
||||
int CreateOperator(const std::unique_ptr<schema::PrimitiveT> &primitive, schema::PrimitiveType type) {
|
||||
auto attr = std::make_unique<T>();
|
||||
|
|
|
@ -27,6 +27,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/../common/tensor_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/string_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/protobuf_utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/func_graph_subgraph.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/storage.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/ir/primitive_t_value.cc
|
||||
|
@ -112,6 +113,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
../optimizer/graph/transpose_strategy.cc
|
||||
../optimizer/graph/reduce_same_act_pass.cc
|
||||
../optimizer/graph/split_one_pass.cc
|
||||
../optimizer/graph/find_const_subgraph_pass.cc
|
||||
)
|
||||
|
||||
add_subdirectory(../anf_exporter anf_exporter)
|
||||
|
|
Loading…
Reference in New Issue