diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc index 241485b357e..bfc0809a114 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc @@ -196,7 +196,7 @@ void AscendStreamAssign::AssignCommonStreamId(const CNodePtr &cur_cnode_ptr) { void AscendStreamAssign::AssignHcom(const NotNull &graph_ptr) { auto cnode_ptr_list = graph_ptr->execution_order(); - std::map> graph_nodes_map; + std::map>> group_graph_nodes_map; for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; // node has been assigned stream before @@ -205,27 +205,52 @@ void AscendStreamAssign::AssignHcom(const NotNull &graph_ptr) { } if (IsHcom(cur_cnode_ptr)) { + if (!AnfAlgo::HasNodeAttr(kAttrGroup, cur_cnode_ptr)) { + MS_LOG_EXCEPTION << "hcom cnode " << cur_cnode_ptr->DebugString() << " has no group attr"; + } + auto group_name = AnfAlgo::GetNodeAttr(cur_cnode_ptr, kAttrGroup); auto hcom_graph_id = AnfAlgo::GetGraphId(cur_cnode_ptr.get()); - auto it = graph_nodes_map.find(hcom_graph_id); - if (it == graph_nodes_map.end()) { + auto iter = group_graph_nodes_map.find(group_name); + if (iter == group_graph_nodes_map.end()) { + std::map> graph_nodes_map; graph_nodes_map[hcom_graph_id] = {cur_cnode_ptr}; + group_graph_nodes_map[group_name] = graph_nodes_map; } else { - it->second.emplace_back(cur_cnode_ptr); + auto &graph_nodes_map = iter->second; + auto it = graph_nodes_map.find(hcom_graph_id); + if (it == graph_nodes_map.end()) { + graph_nodes_map[hcom_graph_id] = {cur_cnode_ptr}; + } else { + it->second.emplace_back(cur_cnode_ptr); + } } } } - MS_LOG(INFO) << "hcom diff graph id size:" << graph_nodes_map.size(); - for (const auto &item : graph_nodes_map) { - bool new_graph = true; - auto graph_id = item.first; - hcom_graph_map_[graph_id] = {}; - for (const auto &hcom_node_ptr : item.second) { - auto assigned_stream_id = AssignHcomStreamId(hcom_node_ptr, new_graph); - hcom_graph_map_[graph_id].emplace(assigned_stream_id); - new_graph = false; - } + + MS_LOG(INFO) << "hcom diff group size:" << group_graph_nodes_map.size(); + for (const auto &item : group_graph_nodes_map) { + MS_LOG_INFO << "group id:" << item.first << "; diff graph id size:" << item.second.size(); + } + + for (const auto &diff_group : group_graph_nodes_map) { + // group id: + std::map> hcom_graph_map; + for (const auto &item : diff_group.second) { + bool new_graph = true; + auto graph_id = item.first; + hcom_graph_map[graph_id] = {}; + for (const auto &hcom_node_ptr : item.second) { + auto assigned_stream_id = AssignHcomStreamId(hcom_node_ptr, new_graph); + hcom_graph_map[graph_id].emplace(assigned_stream_id); + new_graph = false; + } + } + group_hcom_graph_map_[diff_group.first] = hcom_graph_map; + } + + for (const auto &item : group_hcom_graph_map_) { + MS_LOG_INFO << "group id:" << item.first << "; hcom stream nums:" << item.second.size(); } - MS_LOG(INFO) << "hcom stream nums : " << hcom_stream_map_.size(); } uint32_t AscendStreamAssign::AssignHcomStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph) { @@ -337,7 +362,7 @@ void AscendStreamAssign::InsertStreamActive(const NotNull &graph } void AscendStreamAssign::InsertStreamActiveForParallel(const NotNull &graph_ptr) { - if (hcom_graph_map_.empty() && independent_graph_map_.empty()) { + if (group_hcom_graph_map_.empty() && independent_graph_map_.empty()) { MS_LOG(INFO) << "Hcom and independent is empty"; return; } @@ -347,19 +372,32 @@ void AscendStreamAssign::InsertStreamActiveForParallel(const NotNull> other_graph; - for (const auto &item : hcom_graph_map_) { - MS_LOG(INFO) << "Graph id:" << item.first; - if (item.first == root_graph_id) { - if (loop_sink_) { - ActiveRootGraphHcom(graph_ptr, item.second); + std::set hcom_streams; + for (const auto &graph_nodes : group_hcom_graph_map_) { + for (const auto &item : graph_nodes.second) { + MS_LOG(INFO) << "Graph id:" << item.first; + if (item.first == root_graph_id) { + if (loop_sink_) { + hcom_streams.insert(item.second.begin(), item.second.end()); + } + } else { + auto it = other_graph.find(item.first); + if (it == other_graph.end()) { + other_graph[item.first] = item.second; + } else { + for (const auto &stream : item.second) { + it->second.emplace(stream); + } + } } - } else { - other_graph[item.first] = item.second; } } + if (!hcom_streams.empty()) { + ActiveRootGraphHcom(graph_ptr, hcom_streams); + } + MS_LOG(INFO) << "Independent graph map size:" << independent_graph_map_.size(); for (const auto &item : independent_graph_map_) { MS_LOG(DEBUG) << "Graph id:" << item.first; @@ -505,7 +543,6 @@ void AscendStreamAssign::ActiveRootGraphIndependent(const NotNullset_execution_order(update_cnode_list); } - void AscendStreamAssign::InsertStreamActiveForCommon(const NotNull &graph_ptr) { MS_LOG(INFO) << "Start"; GetProcessedStream(graph_ptr); @@ -733,7 +770,7 @@ bool AscendStreamAssign::IsProcessedStream(uint32_t stream_id) { void AscendStreamAssign::InsertEventForHcomParallel(const NotNull &graph_ptr) { MS_LOG(INFO) << "Start"; InsertEventCommonDependHcom(graph_ptr); - InsertEventHcomDependCommon(graph_ptr); + InsertEventHcomDependCommonBak(graph_ptr); InsertEventHcomDependHcom(graph_ptr); MS_LOG(INFO) << "End"; } @@ -777,36 +814,6 @@ void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull &graph_ptr, - const CNodePtr &cur_cnode_ptr) { - auto cnode_ptr_list = graph_ptr->execution_order(); - auto &inputs = cur_cnode_ptr->inputs(); - auto it_pos = cnode_ptr_list.begin(); - for (size_t i = 1; i < inputs.size(); i++) { - if (inputs[i]->isa()) { - auto cnode = inputs[i]->cast(); - while (opt::IsNopNode(cnode)) { - cnode = cnode->inputs()[1]->cast(); - } - - auto it = std::find(it_pos, cnode_ptr_list.end(), cnode); - if (it != cnode_ptr_list.end()) { - it_pos = it; - } - } else { - continue; - } - } - - if (it_pos == cnode_ptr_list.begin() && *it_pos != inputs[1]) { - MS_LOG(EXCEPTION) << "The input of node:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "was not found"; - } - - MS_LOG(INFO) << "The las input of node:" << cur_cnode_ptr->DebugString() << " is:" << (*it_pos)->fullname_with_scope() - << "; name:" << (*it_pos)->DebugString(); - return *it_pos; -} - // after memory reuse is correct, use this function void AscendStreamAssign::InsertEventHcomDependCommonBak(const NotNull &graph_ptr) { AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); @@ -830,7 +837,7 @@ void AscendStreamAssign::InsertEventHcomDependCommonBak(const NotNull &graph_ptr, + const CNodePtr &cur_cnode_ptr) { + auto cnode_ptr_list = graph_ptr->execution_order(); + auto input_cnodes = GetInputKernels(cur_cnode_ptr); + if (input_cnodes.empty()) { + return nullptr; + } + auto it_pos = cnode_ptr_list.begin(); + + for (auto &cnode : input_cnodes) { + auto it = std::find(it_pos, cnode_ptr_list.end(), cnode); + if (it != cnode_ptr_list.end()) { + it_pos = it; + } + } + if (it_pos == cnode_ptr_list.begin() && *it_pos != input_cnodes.front()) { + MS_LOG(ERROR) << "The input of node:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "was not found"; + } + + return *it_pos; +} + +vector AscendStreamAssign::GetInputKernels(const CNodePtr &node) { + vector input_cnodes; + queue nop_nodes; + auto inputs = node->inputs(); + for (size_t i = 1; i < inputs.size(); i++) { + auto real_input = AnfAlgo::VisitKernel(inputs[i], 0); + auto node = real_input.first; + if (opt::IsNopNode(node)) { + nop_nodes.push(node->cast()); + while (!nop_nodes.empty()) { + auto cur_node = nop_nodes.front(); + nop_nodes.pop(); + auto new_inputs = cur_node->inputs(); + for (size_t j = 1; j < new_inputs.size(); j++) { + auto new_real_input = AnfAlgo::VisitKernel(new_inputs[j], 0); + auto new_node = new_real_input.first; + if (opt::IsNopNode(new_node)) { + nop_nodes.push(new_node->cast()); + } else if (new_node->isa()) { + input_cnodes.emplace_back(new_node->cast()); + } + } + } + } else if (node->isa()) { + input_cnodes.emplace_back(node->cast()); + } + } + return input_cnodes; +} + void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull &graph_ptr) { AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); auto cnode_ptr_list = graph_ptr->execution_order(); @@ -896,40 +955,70 @@ void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull &graph_ptr) { AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); auto cnode_ptr_list = graph_ptr->execution_order(); - uint32_t first_hcom_stream = kInvalidStreamId; - uint32_t last_hcom_stream = kInvalidStreamId; - // key: stream id, value:hcom index - std::map> hcom_index; + // key:group id, key: stream id, value:hcom index + std::map>> group_hcom_index; + std::map group_first_hcom_stream; + std::map group_last_hcom_stream; for (size_t i = 0; i < cnode_ptr_list.size(); i++) { auto cur_cnode = cnode_ptr_list[i]; if (!IsHcom(cur_cnode)) { continue; } uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode); - auto it = hcom_index.find(cur_stream_id); - if (it != hcom_index.end()) { - hcom_index[cur_stream_id].emplace_back(i); - } else { + if (!AnfAlgo::HasNodeAttr(kAttrGroup, cur_cnode)) { + MS_LOG_EXCEPTION << "hcom cnode " << cur_cnode->DebugString() << " has no group attr"; + } + auto group_name = AnfAlgo::GetNodeAttr(cur_cnode, kAttrGroup); + auto iter = group_hcom_index.find(group_name); + if (iter == group_hcom_index.end()) { + std::map> hcom_index; hcom_index[cur_stream_id] = {i}; + group_hcom_index[group_name] = hcom_index; + } else { + auto &hcom_index = iter->second; + auto it = hcom_index.find(cur_stream_id); + if (it == hcom_index.end()) { + hcom_index[cur_stream_id] = {i}; + } else { + it->second.emplace_back(i); + } } // record first hcom stream id - if (first_hcom_stream == kInvalidStreamId) { - first_hcom_stream = cur_stream_id; + auto it = group_first_hcom_stream.find(group_name); + if (it == group_first_hcom_stream.end()) { + group_first_hcom_stream[group_name] = cur_stream_id; } // record last hcom stream id - if (cur_stream_id != last_hcom_stream) { - last_hcom_stream = cur_stream_id; + it = group_last_hcom_stream.find(group_name); + if (it != group_last_hcom_stream.end()) { + it->second = cur_stream_id; + } else { + group_last_hcom_stream[group_name] = cur_stream_id; } } - if (hcom_index.size() < 2) { - MS_LOG(INFO) << "Different stream hcom size is less than 2, no need insert event between them"; - return; + for (const auto &hcom_index : group_hcom_index) { + if (hcom_index.second.size() < 2) { + MS_LOG(INFO) << "Different stream hcom size is less than 2, no need insert event between them"; + return; + } + auto group_name = hcom_index.first; + auto it = group_first_hcom_stream.find(group_name); + if (it == group_first_hcom_stream.end()) { + MS_LOG_EXCEPTION << "Can't find first hcom stream, hcom group id:" << group_name; + } + auto first_hcom_stream = it->second; + + it = group_last_hcom_stream.find(group_name); + if (it == group_last_hcom_stream.end()) { + MS_LOG_EXCEPTION << "Can't find last hcom stream, hcom group id:" << group_name; + } + auto last_hcom_stream = it->second; + InsertEventBetweenHcom(graph_ptr, hcom_index.second, first_hcom_stream, last_hcom_stream); + MS_LOG(INFO) << "After hcom depend hcom, total event nums:" << resource_manager.get_cur_event_num(); } - InsertEventBetweenHcom(graph_ptr, hcom_index, first_hcom_stream, last_hcom_stream); - MS_LOG(INFO) << "After hcom depend hcom, total event nums:" << resource_manager.get_cur_event_num(); } void AscendStreamAssign::InsertEventBetweenHcom(const NotNull &graph_ptr, @@ -1199,9 +1288,12 @@ void AscendStreamAssign::GetNeedActiveStreams(const NotNull &gra // 3)hcom stream:if has not been activate, push to need active vector if (!hcom_stream_activated_) { - auto it = hcom_graph_map_.find(root_graph_id); - if (it != hcom_graph_map_.end()) { - std::copy(it->second.begin(), it->second.end(), std::back_inserter(need_first_active_streams_)); + for (const auto &item : group_hcom_graph_map_) { + auto &hcom_graph_map = item.second; + auto it = hcom_graph_map.find(root_graph_id); + if (it != hcom_graph_map.end()) { + std::copy(it->second.begin(), it->second.end(), std::back_inserter(need_first_active_streams_)); + } } } @@ -1434,7 +1526,7 @@ void AscendStreamAssign::Reset() { event_map_.clear(); independent_targets_.clear(); independent_graph_map_.clear(); - hcom_graph_map_.clear(); + group_hcom_graph_map_.clear(); middle_active_streams_.clear(); } diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h index 34c1e41ca1e..c8f6979c342 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -35,6 +36,7 @@ namespace mindspore { namespace device { namespace ascend { using std::map; +using std::queue; using std::shared_ptr; using std::unordered_map; using std::unordered_set; @@ -184,6 +186,7 @@ class AscendStreamAssign { void PrintStreamGroups(); void FindEventRelations(const NotNull &graph_ptr); bool IsSatisfiedEvent(uint32_t send_stream_id, uint32_t recv_stream_id) const; + vector GetInputKernels(const CNodePtr &node); bool independent_stream_activated_{false}; bool hcom_stream_activated_{false}; @@ -195,8 +198,9 @@ class AscendStreamAssign { std::set processed_streams_{}; std::vector need_first_active_streams_{}; std::set independent_targets_; + + std::map>> group_hcom_graph_map_; // key:graph id, value:stream set - std::map> hcom_graph_map_; std::map> independent_graph_map_; // attr for memory copy reuse