delete identity node
This commit is contained in:
parent
4b17adebf8
commit
ec4051d3ae
|
@ -217,6 +217,20 @@ std::vector<AnfNodePtr> GetAnfCallInputs(bool is_kernel_graph, const CNodePtr &c
|
|||
}
|
||||
return inputs;
|
||||
}
|
||||
|
||||
bool HasSubgraph(const std::shared_ptr<AnfGraph> &func_graph) {
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNodePtr>(node)) {
|
||||
continue;
|
||||
}
|
||||
auto sub_graph = GetCNodeFuncGraph(node);
|
||||
if (sub_graph != nullptr) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// ---------------implement of DfGraphConvertor-------------
|
||||
|
@ -1543,7 +1557,7 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph(const std::string &name) {
|
|||
MS_LOG(INFO) << "Set graph input num: " << inputs.size();
|
||||
(void)df_graph_->SetInputs(inputs);
|
||||
|
||||
SetGraphOutputs();
|
||||
SetGraphOutputs(true);
|
||||
(void)df_graph_->SetOutputs(graph_outputs_);
|
||||
|
||||
IdentityOptimization();
|
||||
|
@ -1558,15 +1572,27 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph(const std::string &name) {
|
|||
return *this;
|
||||
}
|
||||
|
||||
void DfGraphConvertor::SetGraphOutputs() {
|
||||
void DfGraphConvertor::SetGraphOutputs(bool is_main_graph) {
|
||||
if (cur_while_node_ == nullptr) {
|
||||
graph_outputs_.clear();
|
||||
std::vector<AnfNodePtr> return_nodes;
|
||||
auto ret_node = anf_graph_->get_return();
|
||||
auto adpt = FindAdapter(ret_node, training_);
|
||||
MS_EXCEPTION_IF_NULL(adpt);
|
||||
auto handles = adpt->getOutputs(Convert(ret_node));
|
||||
for (const auto &handle : handles) {
|
||||
(void)graph_outputs_.emplace_back(std::make_pair(*handle.op, handle.out));
|
||||
// replace return node with graph output node.
|
||||
if (!HasSubgraph(anf_graph_) && is_main_graph) {
|
||||
auto output_nodes = ret_node->inputs();
|
||||
return_nodes.insert(return_nodes.end(), output_nodes.begin() + 1, output_nodes.end());
|
||||
} else {
|
||||
return_nodes.push_back(ret_node);
|
||||
}
|
||||
for (size_t i = 0; i < return_nodes.size(); i++) {
|
||||
auto output_node = return_nodes[i];
|
||||
MS_EXCEPTION_IF_NULL(output_node);
|
||||
auto adpt = FindAdapter(output_node, training_);
|
||||
MS_EXCEPTION_IF_NULL(adpt);
|
||||
auto handles = adpt->getOutputs(Convert(output_node));
|
||||
for (const auto &handle : handles) {
|
||||
(void)graph_outputs_.emplace_back(std::make_pair(*handle.op, handle.out));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -250,7 +250,7 @@ class DfGraphConvertor {
|
|||
void SetDynamicInputHandleByMultiInput(const OpAdapterPtr &adpt, const CNodePtr &node,
|
||||
const CNodePtr &from_node_input);
|
||||
void SetNodeControlInput(const AnfNodePtr &node, const AnfNodePtr &input);
|
||||
void SetGraphOutputs();
|
||||
void SetGraphOutputs(bool is_main_graph = false);
|
||||
std::vector<OutHandler> GetInputHandles(const AnfNodePtr &node, const AnfNodePtr &input);
|
||||
|
||||
// Identity Optimization
|
||||
|
|
Loading…
Reference in New Issue