forked from mindspore-Ecosystem/mindspore
!26567 fix allreduce notify bug
Merge pull request !26567 from hwjiaorui/fix-stream-label
This commit is contained in:
commit
a17849b669
|
@ -1277,7 +1277,8 @@ bool AscendStreamAssign::ExistStreamSendAfterLastHcomNode(const NotNull<KernelGr
|
|||
auto cnodes = graph_ptr->execution_order();
|
||||
for (int64_t i = cnodes.size() - 1; i >= 0; i--) {
|
||||
if (AnfAlgo::GetGraphId(cnodes[i].get()) == graph_id && IsHcom(cnodes[i])) {
|
||||
return AnfAlgo::GetCNodeName(cnodes[i]) == kSendOpName;
|
||||
return (AnfAlgo::GetCNodeName(cnodes[i]) == kSendOpName) ||
|
||||
((i < SizeToLong(cnodes.size() - 1)) && AnfAlgo::GetCNodeName(cnodes[i + 1]) == kSendOpName);
|
||||
}
|
||||
}
|
||||
MS_LOG(WARNING) << "There is no hcom nodes of graph " << graph_id << " in the root graph " << graph_ptr->graph_id();
|
||||
|
@ -1305,24 +1306,29 @@ void AscendStreamAssign::GraphLoopSync(const NotNull<KernelGraphPtr> &root_graph
|
|||
}
|
||||
}
|
||||
|
||||
std::set<std::string> ending_nodes = {kStreamActiveOpName, kLabelGotoOpName, kLabelSetOpName};
|
||||
// insert StreamRecv node before the last node: active, labelgoto, labelset.
|
||||
std::set<std::string> ending_nodes = {kStreamActiveOpName, kLabelGotoOpName};
|
||||
// Insert the StreamRecv node before the last node of the graph when that node is StreamActive or LabelGoto;
|
||||
// otherwise insert it after the last node, provided the next node (which is not in the graph) is LabelSet.
|
||||
for (auto iter = cnodes.end() - 1; iter >= cnodes.begin(); iter--) {
|
||||
if (AnfAlgo::GetGraphId((*iter).get()) != graph_id) {
|
||||
continue;
|
||||
}
|
||||
auto node_name = AnfAlgo::GetCNodeName(*iter);
|
||||
auto cnode = (*iter)->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
CNodePtr recv_cnode = CreateRecvApplyKernel(root_graph, cur_event_id, AnfAlgo::GetStreamId(cnode));
|
||||
if (ending_nodes.find(node_name) != ending_nodes.end()) {
|
||||
auto cnode = (*iter)->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
CNodePtr recv_cnode = CreateRecvApplyKernel(root_graph, cur_event_id, AnfAlgo::GetStreamId(cnode));
|
||||
MS_LOG(INFO) << "Insert StreamRecv " << cur_event_id << " before node: " << (*iter)->fullname_with_scope();
|
||||
iter = cnodes.insert(iter, recv_cnode);
|
||||
break;
|
||||
} else if ((iter < cnodes.end() - 1) && AnfAlgo::GetCNodeName(*(iter + 1)) == kLabelSetOpName) {
|
||||
MS_LOG(INFO) << "Insert StreamRecv " << cur_event_id << "after node: " << (*iter)->fullname_with_scope();
|
||||
iter = cnodes.insert(iter + 1, recv_cnode);
|
||||
break;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "The last node of graph " << graph_id
|
||||
<< " is not in the set <StreamActive, LabelGoto, LabelSet>, whereas is "
|
||||
<< (*iter)->fullname_with_scope();
|
||||
<< " is not in the set <StreamActive, LabelGoto>, whereas is " << (*iter)->fullname_with_scope()
|
||||
<< ", and check whether the next node exists and is LabelSet.";
|
||||
}
|
||||
}
|
||||
root_graph->set_execution_order(cnodes);
|
||||
|
|
Loading…
Reference in New Issue