add subgraph util for func_graph

This commit is contained in:
hangangqiang 2021-08-02 17:01:31 +08:00
parent e6e544dbc4
commit 0e62e061a2
7 changed files with 673 additions and 141 deletions

View File

@ -0,0 +1,555 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/common/func_graph_subgraph.h"
#include <set>
#include <string>
#include <vector>
#include <map>
#include <queue>
#include <utility>
#include "src/common/log_adapter.h"
#include "tools/common/node_util.h"
#include "tools/common/graph_util.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "ops/fusion/partial_fusion.h"
namespace mindspore::lite {
SubGraph::SubGraph(FuncGraphPtr belong_anf, std::string graph_name, const std::set<CNodePtr> &head_nodes)
: belong_anf_(std::move(belong_anf)), name_(std::move(graph_name)) {
InitSubGraphNode(head_nodes);
InitSubGraphInNode();
InitSubGraphOutNode();
}
void SubGraph::Reset(const std::set<CNodePtr> &nodes, const std::set<CNodePtr> &head_nodes) {
this->nodes_ = nodes;
InitSubGraphNode(head_nodes);
InitSubGraphInNode();
InitSubGraphOutNode();
}
std::set<CNodePtr> SubGraph::GetNodes() const { return this->nodes_; }
std::set<CNodePtr> SubGraph::GetInCNodes() const { return this->in_nodes_; }
std::set<CNodePtr> SubGraph::GetInputCNodes() const {
std::set<CNodePtr> inputs;
for (const auto &in_node : in_nodes_) {
if (in_node == nullptr) {
continue;
}
auto input_cnodes = GetInputCNode(in_node);
inputs.insert(input_cnodes.begin(), input_cnodes.end());
}
return inputs;
}
std::set<CNodePtr> SubGraph::GetOutCNodes() const { return this->out_nodes_; }
std::set<CNodePtr> SubGraph::FindCommonOutputs(const SubGraphPtr &subgraph) const {
if (subgraph == nullptr) {
return {};
}
std::set<CNodePtr> outputs_this = this->GetOutputCNodes();
if (this == subgraph.get()) {
return outputs_this;
}
std::set<CNodePtr> outputs_other = subgraph->GetOutputCNodes();
std::set<CNodePtr> common_outputs;
for (const auto &output1 : outputs_this) {
if (output1 == nullptr) {
continue;
}
auto iter = outputs_other.find(output1);
if (iter == outputs_other.end()) {
continue;
}
if (GetInputCNode(output1).size() == 2) {
common_outputs.insert(output1);
}
}
return common_outputs;
}
bool SubGraph::IfDependOnSameNode(const SubGraphPtr &subgraph) const {
if (subgraph == nullptr || this == subgraph.get()) {
return false;
}
std::set<CNodePtr> inputs_this = this->GetInputCNodes();
std::set<CNodePtr> inputs_other = subgraph->GetInputCNodes();
return std::any_of(inputs_this.begin(), inputs_this.end(), [&inputs_other](const CNodePtr &input_this) {
if (input_this == nullptr) {
return false;
}
return (inputs_other.count(input_this) > 0);
});
}
std::set<CNodePtr> SubGraph::GetOutputCNodes() const {
MS_ASSERT(belong_anf_ != nullptr);
MS_ASSERT(belong_anf_->manager() != nullptr);
auto node_users = belong_anf_->manager()->node_users();
std::set<CNodePtr> outputs;
for (const auto &out_node : out_nodes_) {
if (out_node == nullptr) {
continue;
}
auto iter = node_users.find(out_node);
if (iter == node_users.end()) {
continue;
}
auto post_node_pairs = iter->second;
for (const auto &post_node_pair : post_node_pairs) {
auto post_node = post_node_pair.first;
if (post_node == nullptr || !utils::isa<CNodePtr>(post_node)) {
continue;
}
outputs.insert(utils::cast<CNodePtr>(post_node));
}
}
return outputs;
}
void SubGraph::InitSubGraphNode(const std::set<CNodePtr> &head_nodes) {
MS_ASSERT(belong_anf_ != nullptr);
MS_ASSERT(belong_anf_->manager() != nullptr);
auto node_users = belong_anf_->manager()->node_users();
std::queue<CNodePtr> q;
for (const auto &head_node : head_nodes) {
if (head_node == nullptr) {
continue;
}
q.push(head_node);
}
while (!q.empty()) {
auto cur_node = q.front();
MS_ASSERT(cur_node != nullptr);
q.pop();
this->nodes_.insert(cur_node);
// check output-cnode of cur-node only depend on cur-node
auto iter = node_users.find(cur_node);
if (iter == node_users.end()) {
continue;
}
auto post_node_pairs = iter->second;
for (const auto &post_node_pair : post_node_pairs) {
auto post_node = post_node_pair.first;
if (post_node == nullptr || !utils::isa<CNodePtr>(post_node)) {
continue;
}
auto post_cnode = utils::cast<CNodePtr>(post_node);
// return-node should not be include into subgraph absolutely // ut
if (opt::CheckPrimitiveType(post_cnode, prim::kPrimReturn)) {
continue;
}
MS_ASSERT(post_cnode != nullptr);
bool non_depend = true;
// check all inputs of output-cnode
for (const auto &input : post_cnode->inputs()) {
if (input == nullptr) {
continue;
}
// input cnode is not contained in subgraph
if (utils::isa<CNodePtr>(input)) {
auto input_cnode = utils::cast<CNodePtr>(input);
if (this->nodes_.count(input_cnode) == 0) {
non_depend = false;
break;
}
}
// input parameter is a graph input
if (utils::isa<ParameterPtr>(input)) {
auto input_parameter = utils::cast<ParameterPtr>(input);
if (!input_parameter->has_default()) {
non_depend = false;
break;
}
}
}
if (non_depend) {
q.push(post_cnode);
}
}
}
}
void SubGraph::InitSubGraphInNode() {
MS_ASSERT(belong_anf_ != nullptr);
MS_ASSERT(belong_anf_->manager() != nullptr);
auto node_users = belong_anf_->manager()->node_users();
this->in_nodes_.clear();
for (const auto &node : this->nodes_) {
if (node == nullptr) {
continue;
}
if (std::any_of(node->inputs().begin(), node->inputs().end(), [this, &node_users](const auto &input) {
if (input == nullptr) {
return false;
}
if (utils::isa<CNodePtr>(input)) {
auto input_cnode = utils::cast<CNodePtr>(input);
if (this->nodes_.count(input_cnode) == 0) {
return true;
}
}
// graph input or shared weight input // ut
if (utils::isa<ParameterPtr>(input)) {
auto input_parameter = utils::cast<ParameterPtr>(input);
if (!input_parameter->has_default()) {
return true;
}
auto output_pair_iter = node_users.find(input);
if (output_pair_iter != node_users.end() && output_pair_iter->second.size() > 1) {
return true;
}
}
return false;
})) {
in_nodes_.insert(node);
}
}
}
void SubGraph::InitSubGraphOutNode() {
MS_ASSERT(belong_anf_ != nullptr);
MS_ASSERT(belong_anf_->manager() != nullptr);
auto node_users = belong_anf_->manager()->node_users();
this->out_nodes_.clear();
for (const auto &node : this->nodes_) {
if (node == nullptr) {
continue;
}
auto node_users_iter = node_users.find(node);
if (node_users_iter == node_users.end()) {
continue;
}
auto node_output_pairs = node_users_iter->second;
if (!std::any_of(node_output_pairs.begin(), node_output_pairs.end(),
[this](const std::pair<AnfNodePtr, int> &output_pair) {
auto output_node = output_pair.first;
if (output_node == nullptr || !utils::isa<CNodePtr>(output_node)) {
return false;
}
// graph output // ut
if (opt::CheckPrimitiveType(output_node, prim::kPrimReturn)) {
return true;
}
auto output_cnode = utils::cast<CNodePtr>(output_node);
if (this->nodes_.count(output_cnode) == 0) {
return true;
}
return false;
}))
continue;
out_nodes_.insert(node);
}
}
bool SubGraph::MergeSubGraph(const SubGraphPtr &subgraph) {
if (subgraph == nullptr || this == subgraph.get()) {
return false;
}
// if two subgraph has same output, and this output node has only two input cnode which exactly from two
// subgraph, we merge two subgraph, and find more post node
auto common_outputs = this->FindCommonOutputs(subgraph);
if (!common_outputs.empty()) {
auto new_nodes = this->GetNodes();
auto new_nodes2 = subgraph->GetNodes();
new_nodes.insert(new_nodes2.begin(), new_nodes2.end());
new_nodes.insert(common_outputs.begin(), common_outputs.end());
this->Reset(new_nodes, common_outputs);
return true;
}
if (this->IfDependOnSameNode(subgraph)) {
auto new_nodes = this->GetNodes();
auto new_nodes2 = subgraph->GetNodes();
new_nodes.insert(new_nodes2.begin(), new_nodes2.end());
this->Reset(new_nodes);
return true;
}
return false;
}
// iterate node from in_nodes of current subgraph up to input of belong_anf
SubGraphPtr SubGraph::FindBeforeSubGraphInBelongAnf() const {
MS_ASSERT(belong_anf_ == nullptr);
// find before subgraph's nodes
std::queue<CNodePtr> q;
std::set<CNodePtr> before_nodes;
for (const auto &node : this->GetInCNodes()) {
for (const auto &in_cnode : lite::GetInputCNode(node)) {
if (in_cnode == nullptr) {
continue;
}
q.push(in_cnode);
}
}
while (!q.empty()) {
auto cur_cnode = q.front();
MS_ASSERT(cur_cnode != nullptr);
q.pop();
before_nodes.insert(cur_cnode);
for (const auto &in_cnode : lite::GetInputCNode(cur_cnode)) {
q.push(in_cnode);
}
}
// construct before subgraph
auto before_subgraph = std::make_shared<SubGraph>(belong_anf_, this->name_ + "/before_subgraph");
before_subgraph->Reset(before_nodes);
return before_subgraph;
}
// iterate node from output of belong_anf up to out_nodes of current subgraph and before subgraph
SubGraphPtr SubGraph::FindAfterSubGraphInBelongAnf() const {
MS_ASSERT(belong_anf_ == nullptr);
// find before subgraph
auto before_subgraph = this->FindBeforeSubGraphInBelongAnf();
if (before_subgraph == nullptr) {
MS_LOG(ERROR) << "Find before subgraph failed";
return nullptr;
}
// find after subgraph's nodes
std::queue<CNodePtr> q;
std::set<CNodePtr> after_nodes;
auto output_node = belong_anf_->output();
if (!utils::isa<CNodePtr>(output_node)) {
MS_LOG(ERROR) << "Output node of anf should be a cnode: " << output_node->fullname_with_scope();
return nullptr;
}
q.push(utils::cast<CNodePtr>(output_node));
auto subgraph_out_nodes = this->GetOutCNodes();
auto before_out_nodes = before_subgraph->GetOutCNodes();
while (!q.empty()) {
auto cur_cnode = q.front();
MS_ASSERT(cur_cnode != nullptr);
q.pop();
after_nodes.insert(cur_cnode);
for (const auto &in_cnode : lite::GetInputCNode(cur_cnode)) {
if (subgraph_out_nodes.count(in_cnode) == 0 && before_out_nodes.count(in_cnode) == 0) {
q.push(in_cnode);
}
}
}
// construct before subgraph
auto after_subgraph = std::make_shared<SubGraph>(belong_anf_, this->name_ + "/after_subgraph");
after_subgraph->Reset(after_nodes);
return after_subgraph;
}
int SubGraph::CreatePartialInBelongAnf() {
MS_ASSERT(this->belong_anf_ != nullptr);
MS_ASSERT(this->belong_anf_->manager() != nullptr);
// determine func_graph name
std::string graph_name = this->name_;
if (graph_name.empty()) {
if (this->nodes_.empty()) {
graph_name = "subgraph";
} else {
graph_name = (*(this->nodes_.begin()))->fullname_with_scope() + "/subgraph";
}
}
// create func_graph of partial
FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
auto manager = belong_anf_->manager();
manager->AddFuncGraph(func_graph);
func_graph->set_attr("graph_name", MakeValue(graph_name));
func_graph->set_manager(manager);
// create cnode and parameter for func_graph of partial
std::vector<AnfNodePtr> partial_inputs;
std::map<AnfNodePtr, AnfNodePtr> partial_inputs_and_subgraph_input_map;
CreateParameterForPartialSubGraph(func_graph, &partial_inputs, &partial_inputs_and_subgraph_input_map);
CreateCNodeForPartialSubGraph(func_graph, partial_inputs_and_subgraph_input_map);
// add return for func_graph of partial
auto sub_graph_outputs = this->GetOutCNodes();
MS_ASSERT(!sub_graph_outputs.empty());
auto ret = SetFuncGraphOutput(func_graph, sub_graph_outputs);
if (ret != RET_OK) {
MS_LOG(DEBUG) << "Set subgraph output failed";
return ret;
}
// create partial cnode
auto partial_prim = std::make_shared<mindspore::ops::PartialFusion>();
auto graph_value_node = NewValueNode(func_graph);
partial_inputs.insert(partial_inputs.begin(), graph_value_node);
auto partial_cnode = belong_anf_->NewCNode(partial_prim, partial_inputs);
partial_cnode->set_fullname_with_scope(graph_name + "/partial");
for (size_t i = 0; i < partial_inputs.size(); ++i) {
const auto &input = partial_inputs.at(i);
manager->SetEdge(partial_cnode, static_cast<int>(i + 1), input);
}
// create call cnode
std::vector<AnfNodePtr> call_node_inputs{partial_cnode};
auto call_cnode = belong_anf_->NewCNode(call_node_inputs);
call_cnode->set_fullname_with_scope(graph_name + "/call");
// replace belong-graph's output
auto return_node = belong_anf_->get_return();
MS_ASSERT(return_node != nullptr && return_node->inputs().size() == 2);
auto ori_output = return_node->inputs().at(1);
MS_ASSERT(ori_output != nullptr);
manager->Replace(ori_output, call_cnode);
return RET_OK;
}
int SubGraph::SetFuncGraphOutput(const FuncGraphPtr &graph, const std::set<CNodePtr> &outputs) {
std::vector<AnfNodePtr> output_nodes;
output_nodes.insert(output_nodes.end(), outputs.begin(), outputs.end());
return lite::SetFuncGraphOutput(graph, output_nodes);
}
void SubGraph::CreateParameterForPartialSubGraph(
const FuncGraphPtr &sub_graph, std::vector<AnfNodePtr> *partial_inputs,
std::map<AnfNodePtr, AnfNodePtr> *partial_inputs_and_subgraph_input_map) {
MS_ASSERT(sub_graph != nullptr);
MS_ASSERT(partial_inputs != nullptr && partial_inputs->empty());
MS_ASSERT(partial_inputs_and_subgraph_input_map != nullptr && partial_inputs_and_subgraph_input_map->empty());
std::string graph_name = sub_graph->get_attr("graph_name")->ToString();
for (const auto &in_cnode : this->GetInCNodes()) {
if (in_cnode == nullptr) {
continue;
}
for (size_t i = 1; i < in_cnode->inputs().size(); i++) {
auto input = in_cnode->input(i);
if (input == nullptr) {
continue;
}
auto iter = partial_inputs_and_subgraph_input_map->find(input);
if (iter != partial_inputs_and_subgraph_input_map->end()) {
continue;
}
// create subgraph input parameter from cnode and record partial inputs
if (utils::isa<CNodePtr>(input)) {
auto input_cnode = utils::cast<CNodePtr>(input);
if (this->GetNodes().count(input_cnode) > 0) {
continue;
}
partial_inputs->emplace_back(input);
auto new_parameter = sub_graph->add_parameter();
new_parameter->set_name(graph_name + "_input_" + input->fullname_with_scope());
new_parameter->set_abstract(input->abstract());
(*partial_inputs_and_subgraph_input_map)[input] = new_parameter;
}
// create subgraph input parameter from parameter and record partial inputs
// add parameter to func_graph
auto node_users = this->belong_anf_->manager()->node_users();
if (utils::isa<ParameterPtr>(input)) {
auto parameter = utils::cast<ParameterPtr>(input);
// graph input: create a parameter
if (!parameter->has_default()) {
auto new_parameter = sub_graph->add_parameter();
new_parameter->set_name(graph_name + "_input_" + input->fullname_with_scope());
new_parameter->set_abstract(input->abstract());
(*partial_inputs_and_subgraph_input_map)[input] = new_parameter;
partial_inputs->emplace_back(new_parameter);
}
// weight parameter, it depends
auto output_pairs_iter = node_users.find(input);
if (output_pairs_iter != node_users.end() &&
output_pairs_iter->second.size() > 1) { // shared weight: create a parameter
auto new_parameter = sub_graph->add_parameter();
new_parameter->set_name(graph_name + "_input_" + input->fullname_with_scope());
new_parameter->set_abstract(input->abstract());
(*partial_inputs_and_subgraph_input_map)[input] = new_parameter;
partial_inputs->emplace_back(new_parameter);
} else { // not shared weight: move into subgraph
sub_graph->AddNode(input);
input->set_func_graph(sub_graph);
this->belong_anf_->DropNode(input);
}
}
}
}
}
void SubGraph::CreateCNodeForPartialSubGraph(
const FuncGraphPtr &sub_graph, const std::map<AnfNodePtr, AnfNodePtr> &partial_inputs_and_subgraph_input_map) {
MS_ASSERT(sub_graph != nullptr);
// move cnode from belong_graph to subgraph
for (auto &node : this->GetNodes()) {
sub_graph->AddNode(node);
node->set_func_graph(sub_graph);
for (size_t i = 0; i < node->inputs().size(); i++) {
if (node == nullptr || node->inputs().at(i)) {
continue;
}
auto input = node->inputs().at(i);
auto iter = partial_inputs_and_subgraph_input_map.find(input);
if (iter == partial_inputs_and_subgraph_input_map.end()) {
continue;
}
// use SetEdge not set_input, if not, node_user is not updated.
this->belong_anf_->manager()->SetEdge(node, static_cast<int>(i), iter->second);
}
this->belong_anf_->DropNode(node);
}
}
int SubGraph::ApplySubGraph() {
// check
if (this->nodes_.empty()) {
return lite::RET_NO_CHANGE;
}
if (belong_anf_ == nullptr || belong_anf_->manager() == nullptr) {
MS_LOG(DEBUG) << "belong_anf_ or manager is nullptr";
return lite::RET_NO_CHANGE;
}
for (const auto &node : this->nodes_) {
if (node == nullptr) {
continue;
}
if (node->func_graph() != belong_anf_) {
MS_LOG(DEBUG) << "subgraph nodes belong to different func_graph";
return lite::RET_ERROR;
}
}
// create after partial // redirect input of after subgraph
auto after_subgraph = this->FindAfterSubGraphInBelongAnf();
if (after_subgraph == nullptr) {
MS_LOG(DEBUG) << "Create after subgraph failed";
return RET_ERROR;
}
auto ret = after_subgraph->CreatePartialInBelongAnf();
if (ret != RET_OK) {
MS_LOG(DEBUG) << "Create after partial failed";
return RET_ERROR;
}
// merge after partial into subgraph
auto subgraph_nodes = this->nodes_;
auto return_node = belong_anf_->get_return();
MS_ASSERT(return_node != nullptr && return_node->inputs().size() == 2);
auto call_node = return_node->inputs().at(1);
MS_ASSERT(call_node != nullptr && utils::isa<CNodePtr>(call_node));
auto call_cnode = utils::cast<CNodePtr>(call_node);
MS_ASSERT(call_cnode != nullptr && call_cnode->inputs().size() == 1);
auto after_partial_node = call_cnode->inputs().at(0);
MS_ASSERT(after_partial_node != nullptr && utils::isa<CNodePtr>(after_partial));
auto after_partial_cnode = utils::cast<CNodePtr>(after_partial_node);
MS_ASSERT(after_partial_cnode != nullptr);
subgraph_nodes.insert(after_partial_cnode);
subgraph_nodes.insert(call_cnode);
this->Reset(subgraph_nodes);
// create subgraph partial // add partial to main subgraph
ret = this->CreatePartialInBelongAnf();
if (ret != RET_OK) {
MS_LOG(DEBUG) << "Create partial failed";
return RET_ERROR;
}
return RET_OK;
}
} // namespace mindspore::lite

View File

@ -0,0 +1,75 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_FUNC_GRAPH_SUBGRAPH_H
#define MINDSPORE_LITE_TOOLS_FUNC_GRAPH_SUBGRAPH_H
#include <memory>
#include <string>
#include <vector>
#include <map>
#include <set>
#include "src/common/log_adapter.h"
#include "include/errorcode.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
namespace mindspore::lite {
class SubGraph;
using SubGraphPtr = std::shared_ptr<SubGraph>;
class SubGraph {
public:
explicit SubGraph(FuncGraphPtr belong_anf, std::string graph_name = "", const std::set<CNodePtr> &head_nodes = {});
void Reset(const std::set<CNodePtr> &nodes, const std::set<CNodePtr> &head_nodes = {});
bool MergeSubGraph(const SubGraphPtr &subgraph);
std::set<CNodePtr> GetNodes() const;
std::set<CNodePtr> GetInCNodes() const;
std::set<CNodePtr> GetOutCNodes() const;
int ApplySubGraph();
private:
std::set<CNodePtr> GetInputCNodes() const;
std::set<CNodePtr> GetOutputCNodes() const;
// init subgraph methods
void InitSubGraphNode(const std::set<CNodePtr> &head_nodes);
void InitSubGraphInNode();
void InitSubGraphOutNode();
// merge subgraph methods
std::set<CNodePtr> FindCommonOutputs(const SubGraphPtr &subgraph) const;
bool IfDependOnSameNode(const SubGraphPtr &subgraph) const;
// apply subgraph methods
SubGraphPtr FindBeforeSubGraphInBelongAnf() const;
SubGraphPtr FindAfterSubGraphInBelongAnf() const;
void CreateParameterForPartialSubGraph(const FuncGraphPtr &sub_graph, std::vector<AnfNodePtr> *partial_inputs,
std::map<AnfNodePtr, AnfNodePtr> *partial_inputs_and_subgraph_input_map);
void CreateCNodeForPartialSubGraph(const FuncGraphPtr &sub_graph,
const std::map<AnfNodePtr, AnfNodePtr> &partial_inputs_and_subgraph_input_map);
int CreatePartialInBelongAnf();
static int SetFuncGraphOutput(const FuncGraphPtr &graph, const std::set<CNodePtr> &outputs);
private:
std::set<CNodePtr> nodes_;
std::set<CNodePtr> in_nodes_;
std::set<CNodePtr> out_nodes_;
const FuncGraphPtr belong_anf_;
const std::string name_;
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_FUNC_GRAPH_SUBGRAPH_H

View File

@ -26,6 +26,7 @@
#include "tools/common/node_util.h"
#include "src/common/log_adapter.h"
#include "src/common/utils.h"
#include "tools/converter/ops/ops_def.h"
namespace mindspore {
namespace lite {
@ -33,6 +34,29 @@ namespace {
enum QuantBitNum { QuantBitNum_INT8 = 8, QuantBitNum_INT16 = 16 };
const int kZeroPointGap = 128;
} // namespace
int SetFuncGraphOutput(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &outputs) {
if (graph == nullptr || outputs.empty()) {
MS_LOG(DEBUG) << "Input graph is nullptr or outputs is empty";
return RET_INPUT_PARAM_INVALID;
}
if (outputs.size() == 1) {
graph->set_output(outputs.front(), false);
return RET_OK;
}
auto make_tuple_prim_ptr = std::make_shared<lite::MakeTuple>();
if (make_tuple_prim_ptr == nullptr) {
MS_LOG(DEBUG) << "new MakeTuple failed";
return lite::RET_NULL_PTR;
}
auto make_tuple_cnode = graph->NewCNode(make_tuple_prim_ptr, outputs);
if (make_tuple_prim_ptr == nullptr) {
MS_LOG(DEBUG) << "new cnode failed";
return lite::RET_NULL_PTR;
}
make_tuple_cnode->set_fullname_with_scope("return tuple");
graph->set_output(make_tuple_cnode, false);
return RET_OK;
}
OpDefCopyer GetSimpleOpCopyer() {
return [](CNodeT *inCNode) -> std::unique_ptr<CNodeT> {

View File

@ -46,6 +46,8 @@ using OpDefCopyer = std::function<std::unique_ptr<schema::CNodeT>(schema::CNodeT
OpDefCopyer GetSimpleOpCopyer();
int SetFuncGraphOutput(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &outputs);
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, int inputIndexIdx = -1);
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,

View File

@ -28,147 +28,19 @@
namespace mindspore {
namespace lite {
constexpr size_t kInitialSize = 1024;
static const std::vector<schema::PrimitiveType> nhwcOpList = {schema::PrimitiveType_Conv2DBackpropFilterFusion,
schema::PrimitiveType_Conv2DBackpropInputFusion,
schema::PrimitiveType_AvgPoolGrad,
schema::PrimitiveType_MaxPoolGrad,
schema::PrimitiveType_BiasAddGrad,
schema::PrimitiveType_BatchNormGrad,
schema::PrimitiveType_ApplyMomentum,
schema::PrimitiveType_SGD,
schema::PrimitiveType_Adam,
schema::PrimitiveType_ResizeGrad,
schema::PrimitiveType_AvgPoolFusion,
schema::PrimitiveType_MaxPoolFusion,
schema::PrimitiveType_Conv2DFusion,
schema::PrimitiveType_Conv2dTransposeFusion,
schema::PrimitiveType_LRN,
schema::PrimitiveType_Resize,
schema::PrimitiveType_BatchNorm,
schema::PrimitiveType_FusedBatchNorm,
schema::PrimitiveType_PReLUFusion,
schema::PrimitiveType_BiasAdd,
schema::PrimitiveType_SpaceToDepth,
schema::PrimitiveType_DepthToSpace,
schema::PrimitiveType_TopKFusion,
schema::PrimitiveType_BatchToSpace,
schema::PrimitiveType_SpaceToBatch,
schema::PrimitiveType_SpaceToBatchND};
static const std::vector<schema::PrimitiveType> nchwOpList = {schema::PrimitiveType_InstanceNorm};
static const std::vector<schema::PrimitiveType> nhwcOpAllInputList = {
schema::PrimitiveType_AvgPoolGrad, schema::PrimitiveType_MaxPoolGrad,
schema::PrimitiveType_ActivationGrad, schema::PrimitiveType_Conv2DBackpropFilterFusion,
schema::PrimitiveType_BatchNormGrad, schema::PrimitiveType_ResizeGrad};
// index {} mean all inputs need insert
static std::unordered_map<schema::PrimitiveType, std::vector<int>> extNhwcInsertIndex = {
{schema::PrimitiveType_BatchNormGrad, {0, 1}},
{schema::PrimitiveType_Conv2DBackpropFilterFusion, {0, 1}},
{schema::PrimitiveType_ApplyMomentum, {3}},
{schema::PrimitiveType_SGD, {1}},
{schema::PrimitiveType_Adam, {9}}};
static const std::vector<schema::PrimitiveType> fp32FullOpList = {
schema::PrimitiveType_Concat, schema::PrimitiveType_AddFusion,
schema::PrimitiveType_Floor}; // fp32 ops support C4 and nhwc in fp32
static const std::vector<schema::PrimitiveType> int8NeedNhwcOpList = {};
static const std::vector<schema::PrimitiveType> int8OpList = {schema::PrimitiveType_Conv2DFusion,
schema::PrimitiveType_Conv2dTransposeFusion,
schema::PrimitiveType_AddFusion,
schema::PrimitiveType_Transpose,
schema::PrimitiveType_AvgPoolFusion,
schema::PrimitiveType_MaxPoolFusion,
schema::PrimitiveType_Concat,
schema::PrimitiveType_Softmax,
schema::PrimitiveType_Reshape,
schema::PrimitiveType_Activation,
schema::PrimitiveType_Resize,
schema::PrimitiveType_FullConnection,
schema::PrimitiveType_ArgMaxFusion,
schema::PrimitiveType_ArgMinFusion,
schema::PrimitiveType_BatchNorm,
schema::PrimitiveType_FusedBatchNorm,
schema::PrimitiveType_BiasAdd,
schema::PrimitiveType_DivFusion,
schema::PrimitiveType_MulFusion,
schema::PrimitiveType_SliceFusion,
schema::PrimitiveType_Split,
schema::PrimitiveType_Squeeze,
schema::PrimitiveType_SubFusion,
schema::PrimitiveType_StridedSlice,
schema::PrimitiveType_TopKFusion,
schema::PrimitiveType_Unsqueeze,
schema::PrimitiveType_MatMul,
schema::PrimitiveType_PadFusion,
schema::PrimitiveType_ScaleFusion,
schema::PrimitiveType_Cast,
schema::PrimitiveType_Shape,
schema::PrimitiveType_ExpandDims,
schema::PrimitiveType_BatchToSpace,
schema::PrimitiveType_BatchToSpaceND,
schema::PrimitiveType_ReduceFusion,
schema::PrimitiveType_Round,
schema::PrimitiveType_Floor,
schema::PrimitiveType_Ceil,
schema::PrimitiveType_Abs,
schema::PrimitiveType_Sin,
schema::PrimitiveType_Cos,
schema::PrimitiveType_Log,
schema::PrimitiveType_Sqrt,
schema::PrimitiveType_Rsqrt,
schema::PrimitiveType_Square,
schema::PrimitiveType_LogicalNot,
schema::PrimitiveType_SpaceToBatch,
schema::PrimitiveType_SpaceToBatchND,
schema::PrimitiveType_DepthToSpace,
schema::PrimitiveType_PowFusion,
schema::PrimitiveType_GatherNd,
schema::PrimitiveType_LeakyRelu,
schema::PrimitiveType_Gather,
schema::PrimitiveType_Equal,
schema::PrimitiveType_NotEqual,
schema::PrimitiveType_LessEqual,
schema::PrimitiveType_Greater,
schema::PrimitiveType_GreaterEqual,
schema::PrimitiveType_Eltwise,
schema::PrimitiveType_DetectionPostProcess,
schema::PrimitiveType_Crop,
schema::PrimitiveType_PriorBox,
schema::PrimitiveType_QuantDTypeCast,
schema::PrimitiveType_LayerNormFusion,
schema::PrimitiveType_L2NormalizeFusion};
static const std::vector<schema::PrimitiveType> needInsertOpList = {
schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat,
schema::PrimitiveType_PowFusion, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_AddFusion,
schema::PrimitiveType_AddN, schema::PrimitiveType_Split, schema::PrimitiveType_SliceFusion,
schema::PrimitiveType_Crop, schema::PrimitiveType_MulFusion, schema::PrimitiveType_Maximum,
schema::PrimitiveType_ActivationGrad};
static const std::unordered_map<int, int> nc2NhAxisMap = {{0, 0}, {1, -1}, {2, 1}, {3, 2}};
std::unordered_map<int, int> GetNc2NhAxisMap() { return nc2NhAxisMap; }
std::vector<schema::PrimitiveType> GetInsertOpList() { return needInsertOpList; }
std::vector<schema::PrimitiveType> Getfp32FullOpList() { return fp32FullOpList; }
std::vector<schema::PrimitiveType> GetNhwcOpList() { return nhwcOpList; }
std::vector<schema::PrimitiveType> GetNchwOpList() { return nchwOpList; }
std::unordered_map<schema::PrimitiveType, std::vector<int>> GetExtNhwcIndexes() { return extNhwcInsertIndex; }
std::vector<schema::PrimitiveType> GetNhwcAllInputOpList() { return nhwcOpAllInputList; }
std::vector<schema::PrimitiveType> GetUint8NhwcOpList() { return int8NeedNhwcOpList; }
std::vector<schema::PrimitiveType> GetInt8OpList() { return int8OpList; }
std::vector<CNodePtr> GetInputCNode(const CNodePtr &cnode) {
if (cnode == nullptr) {
return {};
}
std::vector<CNodePtr> inputs;
for (const auto &input : cnode->inputs()) {
if (input == nullptr || !utils::isa<CNodePtr>(input)) {
continue;
}
inputs.emplace_back(utils::cast<CNodePtr>(input));
}
return inputs;
}
const schema::Primitive *ConvertToPrimitive(schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb) {
if (primitive_t == nullptr || fbb == nullptr) {

View File

@ -31,6 +31,8 @@
namespace mindspore {
namespace lite {
std::vector<CNodePtr> GetInputCNode(const CNodePtr &cnode);
template <typename T>
int CreateOperator(const std::unique_ptr<schema::PrimitiveT> &primitive, schema::PrimitiveType type) {
auto attr = std::make_unique<T>();

View File

@ -27,6 +27,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/../common/tensor_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/string_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/protobuf_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/func_graph_subgraph.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/storage.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/ir/primitive_t_value.cc
@ -112,6 +113,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/graph/transpose_strategy.cc
../optimizer/graph/reduce_same_act_pass.cc
../optimizer/graph/split_one_pass.cc
../optimizer/graph/find_const_subgraph_pass.cc
)
add_subdirectory(../anf_exporter anf_exporter)