delete identity node

This commit is contained in:
yefeng 2023-03-01 10:06:22 +08:00
parent 4b17adebf8
commit ec4051d3ae
2 changed files with 34 additions and 8 deletions

View File

@ -217,6 +217,20 @@ std::vector<AnfNodePtr> GetAnfCallInputs(bool is_kernel_graph, const CNodePtr &c
}
return inputs;
}
bool HasSubgraph(const std::shared_ptr<AnfGraph> &func_graph) {
auto node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
auto sub_graph = GetCNodeFuncGraph(node);
if (sub_graph != nullptr) {
return true;
}
}
return false;
}
} // namespace
// ---------------implement of DfGraphConvertor-------------
@ -1543,7 +1557,7 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph(const std::string &name) {
MS_LOG(INFO) << "Set graph input num: " << inputs.size();
(void)df_graph_->SetInputs(inputs);
SetGraphOutputs();
SetGraphOutputs(true);
(void)df_graph_->SetOutputs(graph_outputs_);
IdentityOptimization();
@ -1558,15 +1572,27 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph(const std::string &name) {
return *this;
}
void DfGraphConvertor::SetGraphOutputs() {
void DfGraphConvertor::SetGraphOutputs(bool is_main_graph) {
if (cur_while_node_ == nullptr) {
graph_outputs_.clear();
std::vector<AnfNodePtr> return_nodes;
auto ret_node = anf_graph_->get_return();
auto adpt = FindAdapter(ret_node, training_);
MS_EXCEPTION_IF_NULL(adpt);
auto handles = adpt->getOutputs(Convert(ret_node));
for (const auto &handle : handles) {
(void)graph_outputs_.emplace_back(std::make_pair(*handle.op, handle.out));
// replace return node with graph output node.
if (!HasSubgraph(anf_graph_) && is_main_graph) {
auto output_nodes = ret_node->inputs();
return_nodes.insert(return_nodes.end(), output_nodes.begin() + 1, output_nodes.end());
} else {
return_nodes.push_back(ret_node);
}
for (size_t i = 0; i < return_nodes.size(); i++) {
auto output_node = return_nodes[i];
MS_EXCEPTION_IF_NULL(output_node);
auto adpt = FindAdapter(output_node, training_);
MS_EXCEPTION_IF_NULL(adpt);
auto handles = adpt->getOutputs(Convert(output_node));
for (const auto &handle : handles) {
(void)graph_outputs_.emplace_back(std::make_pair(*handle.op, handle.out));
}
}
}

View File

@ -250,7 +250,7 @@ class DfGraphConvertor {
void SetDynamicInputHandleByMultiInput(const OpAdapterPtr &adpt, const CNodePtr &node,
const CNodePtr &from_node_input);
void SetNodeControlInput(const AnfNodePtr &node, const AnfNodePtr &input);
void SetGraphOutputs();
void SetGraphOutputs(bool is_main_graph = false);
std::vector<OutHandler> GetInputHandles(const AnfNodePtr &node, const AnfNodePtr &input);
// Identity Optimization