Optimize KernelGraph::ReplaceNode

Remove unnecessary UpdateNodeEdgeList() calls.
This commit is contained in:
He Wei 2021-09-28 20:19:48 +08:00
parent 7476deba66
commit be1ea10ca7
1 changed files with 11 additions and 25 deletions

View File

@ -912,34 +912,20 @@ void KernelGraph::ReplaceGraphInput(const AnfNodePtr &old_parameter, const AnfNo
void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, const AnfNodePtr &new_anf_node) {
MS_EXCEPTION_IF_NULL(inputs_);
{
std::queue<AnfNodePtr> seed_nodes;
UpdateNodeEdgeList(&seed_nodes);
}
auto it = node_output_edges_.find(old_anf_node);
if (it != node_output_edges_.end()) {
const auto &outputs = it->second;
for (auto &output_node : outputs) {
MS_EXCEPTION_IF_NULL(output_node.first);
auto output_cnode = output_node.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(output_cnode);
auto &output_node_inputs = output_cnode->inputs();
// don't replace node if it is a control edge => output_node.second == 0
if (output_node.second == 0) {
continue;
}
for (size_t i = 1; i < output_node_inputs.size(); i++) {
if (output_node_inputs[i] == old_anf_node) {
output_cnode->set_input(i, new_anf_node);
}
if (it == node_output_edges_.end()) {
MS_LOG(WARNING) << "Old node not found " << old_anf_node->DebugString();
return;
}
for (auto &user : it->second) {
auto user_cnode = dyn_cast<CNode>(user.first);
MS_EXCEPTION_IF_NULL(user_cnode);
auto &inputs = user_cnode->inputs();
for (size_t i = 1; i < inputs.size(); i++) {
if (inputs[i] == old_anf_node) {
user_cnode->set_input(i, new_anf_node);
}
}
// update front to backend map
FrontBackendlMapUpdate(old_anf_node, new_anf_node);
}
{
std::queue<AnfNodePtr> seed_nodes;
UpdateNodeEdgeList(&seed_nodes);
}
}