Optimize KernelGraph::ReplaceNode

Remove unnecessary UpdateNodeEdgeList() calls.
This commit is contained in:
He Wei 2021-09-28 20:19:48 +08:00
parent 7476deba66
commit be1ea10ca7
1 changed files with 11 additions and 25 deletions

View File

@ -912,34 +912,20 @@ void KernelGraph::ReplaceGraphInput(const AnfNodePtr &old_parameter, const AnfNo
void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, const AnfNodePtr &new_anf_node) { void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, const AnfNodePtr &new_anf_node) {
MS_EXCEPTION_IF_NULL(inputs_); MS_EXCEPTION_IF_NULL(inputs_);
{
std::queue<AnfNodePtr> seed_nodes;
UpdateNodeEdgeList(&seed_nodes);
}
auto it = node_output_edges_.find(old_anf_node); auto it = node_output_edges_.find(old_anf_node);
if (it != node_output_edges_.end()) { if (it == node_output_edges_.end()) {
const auto &outputs = it->second; MS_LOG(WARNING) << "Old node not found " << old_anf_node->DebugString();
for (auto &output_node : outputs) { return;
MS_EXCEPTION_IF_NULL(output_node.first); }
auto output_cnode = output_node.first->cast<CNodePtr>(); for (auto &user : it->second) {
MS_EXCEPTION_IF_NULL(output_cnode); auto user_cnode = dyn_cast<CNode>(user.first);
auto &output_node_inputs = output_cnode->inputs(); MS_EXCEPTION_IF_NULL(user_cnode);
// don't replace node if it is a control edge => output_node.second == 0 auto &inputs = user_cnode->inputs();
if (output_node.second == 0) { for (size_t i = 1; i < inputs.size(); i++) {
continue; if (inputs[i] == old_anf_node) {
} user_cnode->set_input(i, new_anf_node);
for (size_t i = 1; i < output_node_inputs.size(); i++) {
if (output_node_inputs[i] == old_anf_node) {
output_cnode->set_input(i, new_anf_node);
}
} }
} }
// update front to backend map
FrontBackendlMapUpdate(old_anf_node, new_anf_node);
}
{
std::queue<AnfNodePtr> seed_nodes;
UpdateNodeEdgeList(&seed_nodes);
} }
} }