From 58403932ee5d90a6b220c5858647c865dc331e0d Mon Sep 17 00:00:00 2001 From: gukecai Date: Thu, 11 Jun 2020 21:10:34 +0800 Subject: [PATCH] add sync bewteen hcom --- .../device/ascend/ascend_stream_assign.cc | 71 +++++++++++++++++++ .../device/ascend/ascend_stream_assign.h | 1 + 2 files changed, 72 insertions(+) diff --git a/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc b/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc index 26ab826a7f..10d98856ec 100644 --- a/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc +++ b/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc @@ -291,6 +291,74 @@ void AscendStreamAssign::FindAllReduceParallel(const shared_ptr &graph_ptr) { + MS_LOG(INFO) << "start"; + MS_EXCEPTION_IF_NULL(graph_ptr); + auto cnode_ptr_list = graph_ptr->execution_order(); + vector fusion_hcom_index; + vector orders; + for (size_t i = 0; i < cnode_ptr_list.size(); i++) { + auto cur_cnode = cnode_ptr_list[i]; + if (IsHcom(cur_cnode)) { + fusion_hcom_index.emplace_back(i); + } + } + + if (fusion_hcom_index.size() < 2) { + MS_LOG(INFO) << "fusion hcom size is less than 2, no need insert event between them"; + return; + } + + uint32_t first_index = fusion_hcom_index[0]; + uint32_t last_index = fusion_hcom_index[fusion_hcom_index.size() - 1]; + + uint32_t cur_event_id = total_event_num_; + uint32_t pre_hcom_stream_id = UINT32_MAX; + std::copy(cnode_ptr_list.begin(), cnode_ptr_list.begin() + first_index, std::back_inserter(orders)); + for (size_t i = first_index; i <= last_index; i++) { + auto cur_cnode = cnode_ptr_list[i]; + auto it = std::find(fusion_hcom_index.begin(), fusion_hcom_index.end(), i); + if (it == fusion_hcom_index.end()) { + orders.emplace_back(cur_cnode); + continue; + } + + auto cur_hcom_stream_id = AnfAlgo::GetStreamId(cur_cnode); + if (cur_hcom_stream_id == pre_hcom_stream_id) { + orders.emplace_back(cur_cnode); + continue; + } + + if (i == first_index) { + // first fusion hcom + orders.emplace_back(cur_cnode); + auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); + orders.emplace_back(send); + } else if (i == last_index) { + // last fusion hcom + auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); + orders.emplace_back(recv); + orders.emplace_back(cur_cnode); + cur_event_id++; + } else { + auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); + orders.emplace_back(recv); + cur_event_id++; + orders.emplace_back(cur_cnode); + auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); + orders.emplace_back(send); + } + + pre_hcom_stream_id = cur_hcom_stream_id; + } + + std::copy(cnode_ptr_list.begin() + last_index + 1, cnode_ptr_list.end(), std::back_inserter(orders)); + graph_ptr->set_execution_order(orders); + total_event_num_ = cur_event_id; + MS_LOG(INFO) << "after indsert between allreduce, total event nums[" << total_event_num_ << "]"; + MS_LOG(INFO) << "end"; +} + void AscendStreamAssign::InsertSendRecvForHcomParallel(const shared_ptr &graph_ptr) { MS_LOG(INFO) << "start"; MS_EXCEPTION_IF_NULL(graph_ptr); @@ -324,6 +392,9 @@ void AscendStreamAssign::InsertSendRecvForHcomParallel(const shared_ptrset_execution_order(cnodes); total_event_num_ = cur_event_id; MS_LOG(INFO) << "after insert send/recv for hcom parallel, total event nums[" << total_event_num_ << "]"; + + // Insert Send/Recv between Hcom(such as:AllReduce1 Send1 Common Recv1 AllReduce2) + InsertSendRecvForDiffHcom(graph_ptr); MS_LOG(INFO) << "end"; } diff --git a/mindspore/ccsrc/device/ascend/ascend_stream_assign.h b/mindspore/ccsrc/device/ascend/ascend_stream_assign.h index 7728e61fb0..4bb55a3d21 100644 --- a/mindspore/ccsrc/device/ascend/ascend_stream_assign.h +++ b/mindspore/ccsrc/device/ascend/ascend_stream_assign.h @@ -95,6 +95,7 @@ class AscendStreamAssign { void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector *parallel_streams); void InsertSendRecvForIndependent(const std::shared_ptr &graph_ptr); void InsertSendRecvForHcomParallel(const std::shared_ptr &graph_ptr); + void InsertSendRecvForDiffHcom(const shared_ptr &graph_ptr); void GetNeedActiveStreams(const std::shared_ptr &graph_ptr); void ReorderIndependentOrders(const std::shared_ptr &graph_ptr);