From 0e62e061a288bffbe03f6d63a4d6c8eb05c0d4ea Mon Sep 17 00:00:00 2001 From: hangangqiang Date: Mon, 2 Aug 2021 17:01:31 +0800 Subject: [PATCH] add subgraph util for func_graph --- .../lite/tools/common/func_graph_subgraph.cc | 555 ++++++++++++++++++ .../lite/tools/common/func_graph_subgraph.h | 75 +++ mindspore/lite/tools/common/graph_util.cc | 24 + mindspore/lite/tools/common/graph_util.h | 2 + mindspore/lite/tools/common/node_util.cc | 154 +---- mindspore/lite/tools/common/node_util.h | 2 + mindspore/lite/tools/converter/CMakeLists.txt | 2 + 7 files changed, 673 insertions(+), 141 deletions(-) create mode 100644 mindspore/lite/tools/common/func_graph_subgraph.cc create mode 100644 mindspore/lite/tools/common/func_graph_subgraph.h diff --git a/mindspore/lite/tools/common/func_graph_subgraph.cc b/mindspore/lite/tools/common/func_graph_subgraph.cc new file mode 100644 index 00000000000..79d900fa277 --- /dev/null +++ b/mindspore/lite/tools/common/func_graph_subgraph.cc @@ -0,0 +1,555 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/common/func_graph_subgraph.h" +#include +#include +#include +#include +#include +#include +#include "src/common/log_adapter.h" +#include "tools/common/node_util.h" +#include "tools/common/graph_util.h" +#include "tools/optimizer/common/gllo_utils.h" +#include "ops/fusion/partial_fusion.h" + +namespace mindspore::lite { +SubGraph::SubGraph(FuncGraphPtr belong_anf, std::string graph_name, const std::set &head_nodes) + : belong_anf_(std::move(belong_anf)), name_(std::move(graph_name)) { + InitSubGraphNode(head_nodes); + InitSubGraphInNode(); + InitSubGraphOutNode(); +} + +void SubGraph::Reset(const std::set &nodes, const std::set &head_nodes) { + this->nodes_ = nodes; + InitSubGraphNode(head_nodes); + InitSubGraphInNode(); + InitSubGraphOutNode(); +} + +std::set SubGraph::GetNodes() const { return this->nodes_; } + +std::set SubGraph::GetInCNodes() const { return this->in_nodes_; } + +std::set SubGraph::GetInputCNodes() const { + std::set inputs; + for (const auto &in_node : in_nodes_) { + if (in_node == nullptr) { + continue; + } + auto input_cnodes = GetInputCNode(in_node); + inputs.insert(input_cnodes.begin(), input_cnodes.end()); + } + return inputs; +} + +std::set SubGraph::GetOutCNodes() const { return this->out_nodes_; } + +std::set SubGraph::FindCommonOutputs(const SubGraphPtr &subgraph) const { + if (subgraph == nullptr) { + return {}; + } + std::set outputs_this = this->GetOutputCNodes(); + if (this == subgraph.get()) { + return outputs_this; + } + std::set outputs_other = subgraph->GetOutputCNodes(); + std::set common_outputs; + for (const auto &output1 : outputs_this) { + if (output1 == nullptr) { + continue; + } + auto iter = outputs_other.find(output1); + if (iter == outputs_other.end()) { + continue; + } + if (GetInputCNode(output1).size() == 2) { + common_outputs.insert(output1); + } + } + return common_outputs; +} + +bool SubGraph::IfDependOnSameNode(const SubGraphPtr &subgraph) const { + if (subgraph == nullptr || this == subgraph.get()) { + return false; + } + std::set inputs_this = this->GetInputCNodes(); + std::set inputs_other = subgraph->GetInputCNodes(); + return std::any_of(inputs_this.begin(), inputs_this.end(), [&inputs_other](const CNodePtr &input_this) { + if (input_this == nullptr) { + return false; + } + return (inputs_other.count(input_this) > 0); + }); +} + +std::set SubGraph::GetOutputCNodes() const { + MS_ASSERT(belong_anf_ != nullptr); + MS_ASSERT(belong_anf_->manager() != nullptr); + auto node_users = belong_anf_->manager()->node_users(); + std::set outputs; + for (const auto &out_node : out_nodes_) { + if (out_node == nullptr) { + continue; + } + auto iter = node_users.find(out_node); + if (iter == node_users.end()) { + continue; + } + auto post_node_pairs = iter->second; + for (const auto &post_node_pair : post_node_pairs) { + auto post_node = post_node_pair.first; + if (post_node == nullptr || !utils::isa(post_node)) { + continue; + } + outputs.insert(utils::cast(post_node)); + } + } + return outputs; +} + +void SubGraph::InitSubGraphNode(const std::set &head_nodes) { + MS_ASSERT(belong_anf_ != nullptr); + MS_ASSERT(belong_anf_->manager() != nullptr); + auto node_users = belong_anf_->manager()->node_users(); + std::queue q; + for (const auto &head_node : head_nodes) { + if (head_node == nullptr) { + continue; + } + q.push(head_node); + } + while (!q.empty()) { + auto cur_node = q.front(); + MS_ASSERT(cur_node != nullptr); + q.pop(); + this->nodes_.insert(cur_node); + // check output-cnode of cur-node only depend on cur-node + auto iter = node_users.find(cur_node); + if (iter == node_users.end()) { + continue; + } + auto post_node_pairs = iter->second; + for (const auto &post_node_pair : post_node_pairs) { + auto post_node = post_node_pair.first; + if (post_node == nullptr || !utils::isa(post_node)) { + continue; + } + auto post_cnode = utils::cast(post_node); + // return-node should not be include into subgraph absolutely // ut + if (opt::CheckPrimitiveType(post_cnode, prim::kPrimReturn)) { + continue; + } + MS_ASSERT(post_cnode != nullptr); + bool non_depend = true; + // check all inputs of output-cnode + for (const auto &input : post_cnode->inputs()) { + if (input == nullptr) { + continue; + } + // input cnode is not contained in subgraph + if (utils::isa(input)) { + auto input_cnode = utils::cast(input); + if (this->nodes_.count(input_cnode) == 0) { + non_depend = false; + break; + } + } + // input parameter is a graph input + if (utils::isa(input)) { + auto input_parameter = utils::cast(input); + if (!input_parameter->has_default()) { + non_depend = false; + break; + } + } + } + if (non_depend) { + q.push(post_cnode); + } + } + } +} + +void SubGraph::InitSubGraphInNode() { + MS_ASSERT(belong_anf_ != nullptr); + MS_ASSERT(belong_anf_->manager() != nullptr); + auto node_users = belong_anf_->manager()->node_users(); + this->in_nodes_.clear(); + for (const auto &node : this->nodes_) { + if (node == nullptr) { + continue; + } + if (std::any_of(node->inputs().begin(), node->inputs().end(), [this, &node_users](const auto &input) { + if (input == nullptr) { + return false; + } + if (utils::isa(input)) { + auto input_cnode = utils::cast(input); + if (this->nodes_.count(input_cnode) == 0) { + return true; + } + } + // graph input or shared weight input // ut + if (utils::isa(input)) { + auto input_parameter = utils::cast(input); + if (!input_parameter->has_default()) { + return true; + } + auto output_pair_iter = node_users.find(input); + if (output_pair_iter != node_users.end() && output_pair_iter->second.size() > 1) { + return true; + } + } + return false; + })) { + in_nodes_.insert(node); + } + } +} + +void SubGraph::InitSubGraphOutNode() { + MS_ASSERT(belong_anf_ != nullptr); + MS_ASSERT(belong_anf_->manager() != nullptr); + auto node_users = belong_anf_->manager()->node_users(); + this->out_nodes_.clear(); + for (const auto &node : this->nodes_) { + if (node == nullptr) { + continue; + } + auto node_users_iter = node_users.find(node); + if (node_users_iter == node_users.end()) { + continue; + } + auto node_output_pairs = node_users_iter->second; + if (!std::any_of(node_output_pairs.begin(), node_output_pairs.end(), + [this](const std::pair &output_pair) { + auto output_node = output_pair.first; + if (output_node == nullptr || !utils::isa(output_node)) { + return false; + } + // graph output // ut + if (opt::CheckPrimitiveType(output_node, prim::kPrimReturn)) { + return true; + } + auto output_cnode = utils::cast(output_node); + if (this->nodes_.count(output_cnode) == 0) { + return true; + } + return false; + })) + continue; + out_nodes_.insert(node); + } +} + +bool SubGraph::MergeSubGraph(const SubGraphPtr &subgraph) { + if (subgraph == nullptr || this == subgraph.get()) { + return false; + } + // if two subgraph has same output, and this output node has only two input cnode which exactly from two + // subgraph, we merge two subgraph, and find more post node + auto common_outputs = this->FindCommonOutputs(subgraph); + if (!common_outputs.empty()) { + auto new_nodes = this->GetNodes(); + auto new_nodes2 = subgraph->GetNodes(); + new_nodes.insert(new_nodes2.begin(), new_nodes2.end()); + new_nodes.insert(common_outputs.begin(), common_outputs.end()); + this->Reset(new_nodes, common_outputs); + return true; + } + + if (this->IfDependOnSameNode(subgraph)) { + auto new_nodes = this->GetNodes(); + auto new_nodes2 = subgraph->GetNodes(); + new_nodes.insert(new_nodes2.begin(), new_nodes2.end()); + this->Reset(new_nodes); + return true; + } + return false; +} + +// iterate node from in_nodes of current subgraph up to input of belong_anf +SubGraphPtr SubGraph::FindBeforeSubGraphInBelongAnf() const { + MS_ASSERT(belong_anf_ == nullptr); + // find before subgraph's nodes + std::queue q; + std::set before_nodes; + for (const auto &node : this->GetInCNodes()) { + for (const auto &in_cnode : lite::GetInputCNode(node)) { + if (in_cnode == nullptr) { + continue; + } + q.push(in_cnode); + } + } + while (!q.empty()) { + auto cur_cnode = q.front(); + MS_ASSERT(cur_cnode != nullptr); + q.pop(); + before_nodes.insert(cur_cnode); + for (const auto &in_cnode : lite::GetInputCNode(cur_cnode)) { + q.push(in_cnode); + } + } + // construct before subgraph + auto before_subgraph = std::make_shared(belong_anf_, this->name_ + "/before_subgraph"); + before_subgraph->Reset(before_nodes); + return before_subgraph; +} + +// iterate node from output of belong_anf up to out_nodes of current subgraph and before subgraph +SubGraphPtr SubGraph::FindAfterSubGraphInBelongAnf() const { + MS_ASSERT(belong_anf_ == nullptr); + // find before subgraph + auto before_subgraph = this->FindBeforeSubGraphInBelongAnf(); + if (before_subgraph == nullptr) { + MS_LOG(ERROR) << "Find before subgraph failed"; + return nullptr; + } + // find after subgraph's nodes + std::queue q; + std::set after_nodes; + auto output_node = belong_anf_->output(); + if (!utils::isa(output_node)) { + MS_LOG(ERROR) << "Output node of anf should be a cnode: " << output_node->fullname_with_scope(); + return nullptr; + } + q.push(utils::cast(output_node)); + auto subgraph_out_nodes = this->GetOutCNodes(); + auto before_out_nodes = before_subgraph->GetOutCNodes(); + while (!q.empty()) { + auto cur_cnode = q.front(); + MS_ASSERT(cur_cnode != nullptr); + q.pop(); + after_nodes.insert(cur_cnode); + for (const auto &in_cnode : lite::GetInputCNode(cur_cnode)) { + if (subgraph_out_nodes.count(in_cnode) == 0 && before_out_nodes.count(in_cnode) == 0) { + q.push(in_cnode); + } + } + } + // construct before subgraph + auto after_subgraph = std::make_shared(belong_anf_, this->name_ + "/after_subgraph"); + after_subgraph->Reset(after_nodes); + return after_subgraph; +} + +int SubGraph::CreatePartialInBelongAnf() { + MS_ASSERT(this->belong_anf_ != nullptr); + MS_ASSERT(this->belong_anf_->manager() != nullptr); + // determine func_graph name + std::string graph_name = this->name_; + if (graph_name.empty()) { + if (this->nodes_.empty()) { + graph_name = "subgraph"; + } else { + graph_name = (*(this->nodes_.begin()))->fullname_with_scope() + "/subgraph"; + } + } + // create func_graph of partial + FuncGraphPtr func_graph = std::make_shared(); + auto manager = belong_anf_->manager(); + manager->AddFuncGraph(func_graph); + func_graph->set_attr("graph_name", MakeValue(graph_name)); + func_graph->set_manager(manager); + // create cnode and parameter for func_graph of partial + std::vector partial_inputs; + std::map partial_inputs_and_subgraph_input_map; + CreateParameterForPartialSubGraph(func_graph, &partial_inputs, &partial_inputs_and_subgraph_input_map); + CreateCNodeForPartialSubGraph(func_graph, partial_inputs_and_subgraph_input_map); + // add return for func_graph of partial + auto sub_graph_outputs = this->GetOutCNodes(); + MS_ASSERT(!sub_graph_outputs.empty()); + auto ret = SetFuncGraphOutput(func_graph, sub_graph_outputs); + if (ret != RET_OK) { + MS_LOG(DEBUG) << "Set subgraph output failed"; + return ret; + } + // create partial cnode + auto partial_prim = std::make_shared(); + auto graph_value_node = NewValueNode(func_graph); + partial_inputs.insert(partial_inputs.begin(), graph_value_node); + auto partial_cnode = belong_anf_->NewCNode(partial_prim, partial_inputs); + partial_cnode->set_fullname_with_scope(graph_name + "/partial"); + for (size_t i = 0; i < partial_inputs.size(); ++i) { + const auto &input = partial_inputs.at(i); + manager->SetEdge(partial_cnode, static_cast(i + 1), input); + } + // create call cnode + std::vector call_node_inputs{partial_cnode}; + auto call_cnode = belong_anf_->NewCNode(call_node_inputs); + call_cnode->set_fullname_with_scope(graph_name + "/call"); + // replace belong-graph's output + auto return_node = belong_anf_->get_return(); + MS_ASSERT(return_node != nullptr && return_node->inputs().size() == 2); + auto ori_output = return_node->inputs().at(1); + MS_ASSERT(ori_output != nullptr); + manager->Replace(ori_output, call_cnode); + return RET_OK; +} + +int SubGraph::SetFuncGraphOutput(const FuncGraphPtr &graph, const std::set &outputs) { + std::vector output_nodes; + output_nodes.insert(output_nodes.end(), outputs.begin(), outputs.end()); + return lite::SetFuncGraphOutput(graph, output_nodes); +} + +void SubGraph::CreateParameterForPartialSubGraph( + const FuncGraphPtr &sub_graph, std::vector *partial_inputs, + std::map *partial_inputs_and_subgraph_input_map) { + MS_ASSERT(sub_graph != nullptr); + MS_ASSERT(partial_inputs != nullptr && partial_inputs->empty()); + MS_ASSERT(partial_inputs_and_subgraph_input_map != nullptr && partial_inputs_and_subgraph_input_map->empty()); + + std::string graph_name = sub_graph->get_attr("graph_name")->ToString(); + for (const auto &in_cnode : this->GetInCNodes()) { + if (in_cnode == nullptr) { + continue; + } + for (size_t i = 1; i < in_cnode->inputs().size(); i++) { + auto input = in_cnode->input(i); + if (input == nullptr) { + continue; + } + auto iter = partial_inputs_and_subgraph_input_map->find(input); + if (iter != partial_inputs_and_subgraph_input_map->end()) { + continue; + } + // create subgraph input parameter from cnode and record partial inputs + if (utils::isa(input)) { + auto input_cnode = utils::cast(input); + if (this->GetNodes().count(input_cnode) > 0) { + continue; + } + partial_inputs->emplace_back(input); + auto new_parameter = sub_graph->add_parameter(); + new_parameter->set_name(graph_name + "_input_" + input->fullname_with_scope()); + new_parameter->set_abstract(input->abstract()); + (*partial_inputs_and_subgraph_input_map)[input] = new_parameter; + } + // create subgraph input parameter from parameter and record partial inputs + // add parameter to func_graph + auto node_users = this->belong_anf_->manager()->node_users(); + if (utils::isa(input)) { + auto parameter = utils::cast(input); + // graph input: create a parameter + if (!parameter->has_default()) { + auto new_parameter = sub_graph->add_parameter(); + new_parameter->set_name(graph_name + "_input_" + input->fullname_with_scope()); + new_parameter->set_abstract(input->abstract()); + (*partial_inputs_and_subgraph_input_map)[input] = new_parameter; + partial_inputs->emplace_back(new_parameter); + } + // weight parameter, it depends + auto output_pairs_iter = node_users.find(input); + if (output_pairs_iter != node_users.end() && + output_pairs_iter->second.size() > 1) { // shared weight: create a parameter + auto new_parameter = sub_graph->add_parameter(); + new_parameter->set_name(graph_name + "_input_" + input->fullname_with_scope()); + new_parameter->set_abstract(input->abstract()); + (*partial_inputs_and_subgraph_input_map)[input] = new_parameter; + partial_inputs->emplace_back(new_parameter); + } else { // not shared weight: move into subgraph + sub_graph->AddNode(input); + input->set_func_graph(sub_graph); + this->belong_anf_->DropNode(input); + } + } + } + } +} + +void SubGraph::CreateCNodeForPartialSubGraph( + const FuncGraphPtr &sub_graph, const std::map &partial_inputs_and_subgraph_input_map) { + MS_ASSERT(sub_graph != nullptr); + // move cnode from belong_graph to subgraph + for (auto &node : this->GetNodes()) { + sub_graph->AddNode(node); + node->set_func_graph(sub_graph); + for (size_t i = 0; i < node->inputs().size(); i++) { + if (node == nullptr || node->inputs().at(i)) { + continue; + } + auto input = node->inputs().at(i); + auto iter = partial_inputs_and_subgraph_input_map.find(input); + if (iter == partial_inputs_and_subgraph_input_map.end()) { + continue; + } + // use SetEdge not set_input, if not, node_user is not updated. + this->belong_anf_->manager()->SetEdge(node, static_cast(i), iter->second); + } + this->belong_anf_->DropNode(node); + } +} + +int SubGraph::ApplySubGraph() { + // check + if (this->nodes_.empty()) { + return lite::RET_NO_CHANGE; + } + if (belong_anf_ == nullptr || belong_anf_->manager() == nullptr) { + MS_LOG(DEBUG) << "belong_anf_ or manager is nullptr"; + return lite::RET_NO_CHANGE; + } + for (const auto &node : this->nodes_) { + if (node == nullptr) { + continue; + } + if (node->func_graph() != belong_anf_) { + MS_LOG(DEBUG) << "subgraph nodes belong to different func_graph"; + return lite::RET_ERROR; + } + } + + // create after partial // redirect input of after subgraph + auto after_subgraph = this->FindAfterSubGraphInBelongAnf(); + if (after_subgraph == nullptr) { + MS_LOG(DEBUG) << "Create after subgraph failed"; + return RET_ERROR; + } + auto ret = after_subgraph->CreatePartialInBelongAnf(); + if (ret != RET_OK) { + MS_LOG(DEBUG) << "Create after partial failed"; + return RET_ERROR; + } + // merge after partial into subgraph + auto subgraph_nodes = this->nodes_; + auto return_node = belong_anf_->get_return(); + MS_ASSERT(return_node != nullptr && return_node->inputs().size() == 2); + auto call_node = return_node->inputs().at(1); + MS_ASSERT(call_node != nullptr && utils::isa(call_node)); + auto call_cnode = utils::cast(call_node); + MS_ASSERT(call_cnode != nullptr && call_cnode->inputs().size() == 1); + auto after_partial_node = call_cnode->inputs().at(0); + MS_ASSERT(after_partial_node != nullptr && utils::isa(after_partial)); + auto after_partial_cnode = utils::cast(after_partial_node); + MS_ASSERT(after_partial_cnode != nullptr); + subgraph_nodes.insert(after_partial_cnode); + subgraph_nodes.insert(call_cnode); + this->Reset(subgraph_nodes); + // create subgraph partial // add partial to main subgraph + ret = this->CreatePartialInBelongAnf(); + if (ret != RET_OK) { + MS_LOG(DEBUG) << "Create partial failed"; + return RET_ERROR; + } + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/tools/common/func_graph_subgraph.h b/mindspore/lite/tools/common/func_graph_subgraph.h new file mode 100644 index 00000000000..539c3db502c --- /dev/null +++ b/mindspore/lite/tools/common/func_graph_subgraph.h @@ -0,0 +1,75 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_FUNC_GRAPH_SUBGRAPH_H +#define MINDSPORE_LITE_TOOLS_FUNC_GRAPH_SUBGRAPH_H + +#include +#include +#include +#include +#include +#include "src/common/log_adapter.h" +#include "include/errorcode.h" +#include "ir/anf.h" +#include "ir/func_graph.h" + +namespace mindspore::lite { +class SubGraph; +using SubGraphPtr = std::shared_ptr; +class SubGraph { + public: + explicit SubGraph(FuncGraphPtr belong_anf, std::string graph_name = "", const std::set &head_nodes = {}); + + void Reset(const std::set &nodes, const std::set &head_nodes = {}); + + bool MergeSubGraph(const SubGraphPtr &subgraph); + + std::set GetNodes() const; + std::set GetInCNodes() const; + std::set GetOutCNodes() const; + + int ApplySubGraph(); + + private: + std::set GetInputCNodes() const; + std::set GetOutputCNodes() const; + // init subgraph methods + void InitSubGraphNode(const std::set &head_nodes); + void InitSubGraphInNode(); + void InitSubGraphOutNode(); + // merge subgraph methods + std::set FindCommonOutputs(const SubGraphPtr &subgraph) const; + bool IfDependOnSameNode(const SubGraphPtr &subgraph) const; + // apply subgraph methods + SubGraphPtr FindBeforeSubGraphInBelongAnf() const; + SubGraphPtr FindAfterSubGraphInBelongAnf() const; + void CreateParameterForPartialSubGraph(const FuncGraphPtr &sub_graph, std::vector *partial_inputs, + std::map *partial_inputs_and_subgraph_input_map); + void CreateCNodeForPartialSubGraph(const FuncGraphPtr &sub_graph, + const std::map &partial_inputs_and_subgraph_input_map); + int CreatePartialInBelongAnf(); + static int SetFuncGraphOutput(const FuncGraphPtr &graph, const std::set &outputs); + + private: + std::set nodes_; + std::set in_nodes_; + std::set out_nodes_; + const FuncGraphPtr belong_anf_; + const std::string name_; +}; +} // namespace mindspore::lite +#endif // MINDSPORE_LITE_TOOLS_FUNC_GRAPH_SUBGRAPH_H diff --git a/mindspore/lite/tools/common/graph_util.cc b/mindspore/lite/tools/common/graph_util.cc index 9e9c1ba552c..2e6407a63cf 100644 --- a/mindspore/lite/tools/common/graph_util.cc +++ b/mindspore/lite/tools/common/graph_util.cc @@ -26,6 +26,7 @@ #include "tools/common/node_util.h" #include "src/common/log_adapter.h" #include "src/common/utils.h" +#include "tools/converter/ops/ops_def.h" namespace mindspore { namespace lite { @@ -33,6 +34,29 @@ namespace { enum QuantBitNum { QuantBitNum_INT8 = 8, QuantBitNum_INT16 = 16 }; const int kZeroPointGap = 128; } // namespace +int SetFuncGraphOutput(const FuncGraphPtr &graph, const std::vector &outputs) { + if (graph == nullptr || outputs.empty()) { + MS_LOG(DEBUG) << "Input graph is nullptr or outputs is empty"; + return RET_INPUT_PARAM_INVALID; + } + if (outputs.size() == 1) { + graph->set_output(outputs.front(), false); + return RET_OK; + } + auto make_tuple_prim_ptr = std::make_shared(); + if (make_tuple_prim_ptr == nullptr) { + MS_LOG(DEBUG) << "new MakeTuple failed"; + return lite::RET_NULL_PTR; + } + auto make_tuple_cnode = graph->NewCNode(make_tuple_prim_ptr, outputs); + if (make_tuple_prim_ptr == nullptr) { + MS_LOG(DEBUG) << "new cnode failed"; + return lite::RET_NULL_PTR; + } + make_tuple_cnode->set_fullname_with_scope("return tuple"); + graph->set_output(make_tuple_cnode, false); + return RET_OK; +} OpDefCopyer GetSimpleOpCopyer() { return [](CNodeT *inCNode) -> std::unique_ptr { diff --git a/mindspore/lite/tools/common/graph_util.h b/mindspore/lite/tools/common/graph_util.h index 1fc3f60dbf0..720b9111085 100644 --- a/mindspore/lite/tools/common/graph_util.h +++ b/mindspore/lite/tools/common/graph_util.h @@ -46,6 +46,8 @@ using OpDefCopyer = std::function(schema::CNodeT OpDefCopyer GetSimpleOpCopyer(); +int SetFuncGraphOutput(const FuncGraphPtr &graph, const std::vector &outputs); + std::vector GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, int inputIndexIdx = -1); std::vector GetInputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node, diff --git a/mindspore/lite/tools/common/node_util.cc b/mindspore/lite/tools/common/node_util.cc index be8b204c342..65d6a8659e9 100644 --- a/mindspore/lite/tools/common/node_util.cc +++ b/mindspore/lite/tools/common/node_util.cc @@ -28,147 +28,19 @@ namespace mindspore { namespace lite { constexpr size_t kInitialSize = 1024; - -static const std::vector nhwcOpList = {schema::PrimitiveType_Conv2DBackpropFilterFusion, - schema::PrimitiveType_Conv2DBackpropInputFusion, - schema::PrimitiveType_AvgPoolGrad, - schema::PrimitiveType_MaxPoolGrad, - schema::PrimitiveType_BiasAddGrad, - schema::PrimitiveType_BatchNormGrad, - schema::PrimitiveType_ApplyMomentum, - schema::PrimitiveType_SGD, - schema::PrimitiveType_Adam, - schema::PrimitiveType_ResizeGrad, - schema::PrimitiveType_AvgPoolFusion, - schema::PrimitiveType_MaxPoolFusion, - schema::PrimitiveType_Conv2DFusion, - schema::PrimitiveType_Conv2dTransposeFusion, - schema::PrimitiveType_LRN, - schema::PrimitiveType_Resize, - schema::PrimitiveType_BatchNorm, - schema::PrimitiveType_FusedBatchNorm, - schema::PrimitiveType_PReLUFusion, - schema::PrimitiveType_BiasAdd, - schema::PrimitiveType_SpaceToDepth, - schema::PrimitiveType_DepthToSpace, - schema::PrimitiveType_TopKFusion, - schema::PrimitiveType_BatchToSpace, - schema::PrimitiveType_SpaceToBatch, - schema::PrimitiveType_SpaceToBatchND}; - -static const std::vector nchwOpList = {schema::PrimitiveType_InstanceNorm}; - -static const std::vector nhwcOpAllInputList = { - schema::PrimitiveType_AvgPoolGrad, schema::PrimitiveType_MaxPoolGrad, - schema::PrimitiveType_ActivationGrad, schema::PrimitiveType_Conv2DBackpropFilterFusion, - schema::PrimitiveType_BatchNormGrad, schema::PrimitiveType_ResizeGrad}; - -// index {} mean all inputs need insert -static std::unordered_map> extNhwcInsertIndex = { - {schema::PrimitiveType_BatchNormGrad, {0, 1}}, - {schema::PrimitiveType_Conv2DBackpropFilterFusion, {0, 1}}, - {schema::PrimitiveType_ApplyMomentum, {3}}, - {schema::PrimitiveType_SGD, {1}}, - {schema::PrimitiveType_Adam, {9}}}; - -static const std::vector fp32FullOpList = { - schema::PrimitiveType_Concat, schema::PrimitiveType_AddFusion, - schema::PrimitiveType_Floor}; // fp32 ops support C4 and nhwc in fp32 - -static const std::vector int8NeedNhwcOpList = {}; - -static const std::vector int8OpList = {schema::PrimitiveType_Conv2DFusion, - schema::PrimitiveType_Conv2dTransposeFusion, - schema::PrimitiveType_AddFusion, - schema::PrimitiveType_Transpose, - schema::PrimitiveType_AvgPoolFusion, - schema::PrimitiveType_MaxPoolFusion, - schema::PrimitiveType_Concat, - schema::PrimitiveType_Softmax, - schema::PrimitiveType_Reshape, - schema::PrimitiveType_Activation, - schema::PrimitiveType_Resize, - schema::PrimitiveType_FullConnection, - schema::PrimitiveType_ArgMaxFusion, - schema::PrimitiveType_ArgMinFusion, - schema::PrimitiveType_BatchNorm, - schema::PrimitiveType_FusedBatchNorm, - schema::PrimitiveType_BiasAdd, - schema::PrimitiveType_DivFusion, - schema::PrimitiveType_MulFusion, - schema::PrimitiveType_SliceFusion, - schema::PrimitiveType_Split, - schema::PrimitiveType_Squeeze, - schema::PrimitiveType_SubFusion, - schema::PrimitiveType_StridedSlice, - schema::PrimitiveType_TopKFusion, - schema::PrimitiveType_Unsqueeze, - schema::PrimitiveType_MatMul, - schema::PrimitiveType_PadFusion, - schema::PrimitiveType_ScaleFusion, - schema::PrimitiveType_Cast, - schema::PrimitiveType_Shape, - schema::PrimitiveType_ExpandDims, - schema::PrimitiveType_BatchToSpace, - schema::PrimitiveType_BatchToSpaceND, - schema::PrimitiveType_ReduceFusion, - schema::PrimitiveType_Round, - schema::PrimitiveType_Floor, - schema::PrimitiveType_Ceil, - schema::PrimitiveType_Abs, - schema::PrimitiveType_Sin, - schema::PrimitiveType_Cos, - schema::PrimitiveType_Log, - schema::PrimitiveType_Sqrt, - schema::PrimitiveType_Rsqrt, - schema::PrimitiveType_Square, - schema::PrimitiveType_LogicalNot, - schema::PrimitiveType_SpaceToBatch, - schema::PrimitiveType_SpaceToBatchND, - schema::PrimitiveType_DepthToSpace, - schema::PrimitiveType_PowFusion, - schema::PrimitiveType_GatherNd, - schema::PrimitiveType_LeakyRelu, - schema::PrimitiveType_Gather, - schema::PrimitiveType_Equal, - schema::PrimitiveType_NotEqual, - schema::PrimitiveType_LessEqual, - schema::PrimitiveType_Greater, - schema::PrimitiveType_GreaterEqual, - schema::PrimitiveType_Eltwise, - schema::PrimitiveType_DetectionPostProcess, - schema::PrimitiveType_Crop, - schema::PrimitiveType_PriorBox, - schema::PrimitiveType_QuantDTypeCast, - schema::PrimitiveType_LayerNormFusion, - schema::PrimitiveType_L2NormalizeFusion}; - -static const std::vector needInsertOpList = { - schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat, - schema::PrimitiveType_PowFusion, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_AddFusion, - schema::PrimitiveType_AddN, schema::PrimitiveType_Split, schema::PrimitiveType_SliceFusion, - schema::PrimitiveType_Crop, schema::PrimitiveType_MulFusion, schema::PrimitiveType_Maximum, - schema::PrimitiveType_ActivationGrad}; - -static const std::unordered_map nc2NhAxisMap = {{0, 0}, {1, -1}, {2, 1}, {3, 2}}; - -std::unordered_map GetNc2NhAxisMap() { return nc2NhAxisMap; } - -std::vector GetInsertOpList() { return needInsertOpList; } - -std::vector Getfp32FullOpList() { return fp32FullOpList; } - -std::vector GetNhwcOpList() { return nhwcOpList; } - -std::vector GetNchwOpList() { return nchwOpList; } - -std::unordered_map> GetExtNhwcIndexes() { return extNhwcInsertIndex; } - -std::vector GetNhwcAllInputOpList() { return nhwcOpAllInputList; } - -std::vector GetUint8NhwcOpList() { return int8NeedNhwcOpList; } - -std::vector GetInt8OpList() { return int8OpList; } +std::vector GetInputCNode(const CNodePtr &cnode) { + if (cnode == nullptr) { + return {}; + } + std::vector inputs; + for (const auto &input : cnode->inputs()) { + if (input == nullptr || !utils::isa(input)) { + continue; + } + inputs.emplace_back(utils::cast(input)); + } + return inputs; +} const schema::Primitive *ConvertToPrimitive(schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb) { if (primitive_t == nullptr || fbb == nullptr) { diff --git a/mindspore/lite/tools/common/node_util.h b/mindspore/lite/tools/common/node_util.h index 7fcba451927..6a2f1a560ae 100644 --- a/mindspore/lite/tools/common/node_util.h +++ b/mindspore/lite/tools/common/node_util.h @@ -31,6 +31,8 @@ namespace mindspore { namespace lite { +std::vector GetInputCNode(const CNodePtr &cnode); + template int CreateOperator(const std::unique_ptr &primitive, schema::PrimitiveType type) { auto attr = std::make_unique(); diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index ddffc0809fc..7ff34f83de9 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -27,6 +27,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/../common/tensor_util.cc ${CMAKE_CURRENT_SOURCE_DIR}/../common/string_util.cc ${CMAKE_CURRENT_SOURCE_DIR}/../common/protobuf_utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../common/func_graph_subgraph.cc ${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc ${CMAKE_CURRENT_SOURCE_DIR}/../common/storage.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../src/ir/primitive_t_value.cc @@ -112,6 +113,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ../optimizer/graph/transpose_strategy.cc ../optimizer/graph/reduce_same_act_pass.cc ../optimizer/graph/split_one_pass.cc + ../optimizer/graph/find_const_subgraph_pass.cc ) add_subdirectory(../anf_exporter anf_exporter)