forked from mindspore-Ecosystem/mindspore
!49646 [MS][LITE][Ascend] delete identity
Merge pull request !49646 from yefeng/529-copy-528
This commit is contained in:
commit
4d812f248d
|
@ -217,6 +217,20 @@ std::vector<AnfNodePtr> GetAnfCallInputs(bool is_kernel_graph, const CNodePtr &c
|
||||||
}
|
}
|
||||||
return inputs;
|
return inputs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool HasSubgraph(const std::shared_ptr<AnfGraph> &func_graph) {
|
||||||
|
auto node_list = TopoSort(func_graph->get_return());
|
||||||
|
for (auto &node : node_list) {
|
||||||
|
if (!utils::isa<CNodePtr>(node)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto sub_graph = GetCNodeFuncGraph(node);
|
||||||
|
if (sub_graph != nullptr) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// ---------------implement of DfGraphConvertor-------------
|
// ---------------implement of DfGraphConvertor-------------
|
||||||
|
@ -1543,7 +1557,7 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph(const std::string &name) {
|
||||||
MS_LOG(INFO) << "Set graph input num: " << inputs.size();
|
MS_LOG(INFO) << "Set graph input num: " << inputs.size();
|
||||||
(void)df_graph_->SetInputs(inputs);
|
(void)df_graph_->SetInputs(inputs);
|
||||||
|
|
||||||
SetGraphOutputs();
|
SetGraphOutputs(true);
|
||||||
(void)df_graph_->SetOutputs(graph_outputs_);
|
(void)df_graph_->SetOutputs(graph_outputs_);
|
||||||
|
|
||||||
IdentityOptimization();
|
IdentityOptimization();
|
||||||
|
@ -1560,15 +1574,27 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph(const std::string &name) {
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
void DfGraphConvertor::SetGraphOutputs() {
|
void DfGraphConvertor::SetGraphOutputs(bool is_main_graph) {
|
||||||
if (cur_while_node_ == nullptr) {
|
if (cur_while_node_ == nullptr) {
|
||||||
graph_outputs_.clear();
|
graph_outputs_.clear();
|
||||||
|
std::vector<AnfNodePtr> return_nodes;
|
||||||
auto ret_node = anf_graph_->get_return();
|
auto ret_node = anf_graph_->get_return();
|
||||||
auto adpt = FindAdapter(ret_node, training_);
|
// replace return node with graph output node.
|
||||||
MS_EXCEPTION_IF_NULL(adpt);
|
if (!HasSubgraph(anf_graph_) && is_main_graph) {
|
||||||
auto handles = adpt->getOutputs(Convert(ret_node));
|
auto output_nodes = ret_node->inputs();
|
||||||
for (const auto &handle : handles) {
|
return_nodes.insert(return_nodes.end(), output_nodes.begin() + 1, output_nodes.end());
|
||||||
(void)graph_outputs_.emplace_back(std::make_pair(*handle.op, handle.out));
|
} else {
|
||||||
|
return_nodes.push_back(ret_node);
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < return_nodes.size(); i++) {
|
||||||
|
auto output_node = return_nodes[i];
|
||||||
|
MS_EXCEPTION_IF_NULL(output_node);
|
||||||
|
auto adpt = FindAdapter(output_node, training_);
|
||||||
|
MS_EXCEPTION_IF_NULL(adpt);
|
||||||
|
auto handles = adpt->getOutputs(Convert(output_node));
|
||||||
|
for (const auto &handle : handles) {
|
||||||
|
(void)graph_outputs_.emplace_back(std::make_pair(*handle.op, handle.out));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -250,7 +250,7 @@ class DfGraphConvertor {
|
||||||
void SetDynamicInputHandleByMultiInput(const OpAdapterPtr &adpt, const CNodePtr &node,
|
void SetDynamicInputHandleByMultiInput(const OpAdapterPtr &adpt, const CNodePtr &node,
|
||||||
const CNodePtr &from_node_input);
|
const CNodePtr &from_node_input);
|
||||||
void SetNodeControlInput(const AnfNodePtr &node, const AnfNodePtr &input);
|
void SetNodeControlInput(const AnfNodePtr &node, const AnfNodePtr &input);
|
||||||
void SetGraphOutputs();
|
void SetGraphOutputs(bool is_main_graph = false);
|
||||||
std::vector<OutHandler> GetInputHandles(const AnfNodePtr &node, const AnfNodePtr &input);
|
std::vector<OutHandler> GetInputHandles(const AnfNodePtr &node, const AnfNodePtr &input);
|
||||||
|
|
||||||
// Identity Optimization
|
// Identity Optimization
|
||||||
|
|
Loading…
Reference in New Issue