forked from OSSInnovation/mindspore
add sync bewteen hcom
This commit is contained in:
parent
3277a63e7d
commit
58403932ee
|
@ -291,6 +291,74 @@ void AscendStreamAssign::FindAllReduceParallel(const shared_ptr<session::KernelG
|
|||
}
|
||||
}
|
||||
|
||||
void AscendStreamAssign::InsertSendRecvForDiffHcom(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr) {
|
||||
MS_LOG(INFO) << "start";
|
||||
MS_EXCEPTION_IF_NULL(graph_ptr);
|
||||
auto cnode_ptr_list = graph_ptr->execution_order();
|
||||
vector<uint32_t> fusion_hcom_index;
|
||||
vector<CNodePtr> orders;
|
||||
for (size_t i = 0; i < cnode_ptr_list.size(); i++) {
|
||||
auto cur_cnode = cnode_ptr_list[i];
|
||||
if (IsHcom(cur_cnode)) {
|
||||
fusion_hcom_index.emplace_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
if (fusion_hcom_index.size() < 2) {
|
||||
MS_LOG(INFO) << "fusion hcom size is less than 2, no need insert event between them";
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t first_index = fusion_hcom_index[0];
|
||||
uint32_t last_index = fusion_hcom_index[fusion_hcom_index.size() - 1];
|
||||
|
||||
uint32_t cur_event_id = total_event_num_;
|
||||
uint32_t pre_hcom_stream_id = UINT32_MAX;
|
||||
std::copy(cnode_ptr_list.begin(), cnode_ptr_list.begin() + first_index, std::back_inserter(orders));
|
||||
for (size_t i = first_index; i <= last_index; i++) {
|
||||
auto cur_cnode = cnode_ptr_list[i];
|
||||
auto it = std::find(fusion_hcom_index.begin(), fusion_hcom_index.end(), i);
|
||||
if (it == fusion_hcom_index.end()) {
|
||||
orders.emplace_back(cur_cnode);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto cur_hcom_stream_id = AnfAlgo::GetStreamId(cur_cnode);
|
||||
if (cur_hcom_stream_id == pre_hcom_stream_id) {
|
||||
orders.emplace_back(cur_cnode);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (i == first_index) {
|
||||
// first fusion hcom
|
||||
orders.emplace_back(cur_cnode);
|
||||
auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
|
||||
orders.emplace_back(send);
|
||||
} else if (i == last_index) {
|
||||
// last fusion hcom
|
||||
auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
|
||||
orders.emplace_back(recv);
|
||||
orders.emplace_back(cur_cnode);
|
||||
cur_event_id++;
|
||||
} else {
|
||||
auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
|
||||
orders.emplace_back(recv);
|
||||
cur_event_id++;
|
||||
orders.emplace_back(cur_cnode);
|
||||
auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
|
||||
orders.emplace_back(send);
|
||||
}
|
||||
|
||||
pre_hcom_stream_id = cur_hcom_stream_id;
|
||||
}
|
||||
|
||||
std::copy(cnode_ptr_list.begin() + last_index + 1, cnode_ptr_list.end(), std::back_inserter(orders));
|
||||
graph_ptr->set_execution_order(orders);
|
||||
total_event_num_ = cur_event_id;
|
||||
MS_LOG(INFO) << "after indsert between allreduce, total event nums[" << total_event_num_ << "]";
|
||||
MS_LOG(INFO) << "end";
|
||||
}
|
||||
|
||||
void AscendStreamAssign::InsertSendRecvForHcomParallel(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr) {
|
||||
MS_LOG(INFO) << "start";
|
||||
MS_EXCEPTION_IF_NULL(graph_ptr);
|
||||
|
@ -324,6 +392,9 @@ void AscendStreamAssign::InsertSendRecvForHcomParallel(const shared_ptr<mindspor
|
|||
graph_ptr->set_execution_order(cnodes);
|
||||
total_event_num_ = cur_event_id;
|
||||
MS_LOG(INFO) << "after insert send/recv for hcom parallel, total event nums[" << total_event_num_ << "]";
|
||||
|
||||
// Insert Send/Recv between Hcom(such as:AllReduce1 Send1 Common Recv1 AllReduce2)
|
||||
InsertSendRecvForDiffHcom(graph_ptr);
|
||||
MS_LOG(INFO) << "end";
|
||||
}
|
||||
|
||||
|
|
|
@ -95,6 +95,7 @@ class AscendStreamAssign {
|
|||
void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t> *parallel_streams);
|
||||
void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||
void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||
void InsertSendRecvForDiffHcom(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr);
|
||||
void GetNeedActiveStreams(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||
void ReorderIndependentOrders(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||
|
||||
|
|
Loading…
Reference in New Issue