!1598 reorder DropoutGenMask nodes for stream parallel
Merge pull request !1598 from gukecai/reorder-genmask
This commit is contained in:
commit
fe6b279cee
|
@ -531,6 +531,7 @@ void AscendStreamAssign::GetNeedActiveStreams(const shared_ptr<session::KernelGr
|
|||
void AscendStreamAssign::AssignStreamNew(const shared_ptr<session::KernelGraph> &graph_ptr) {
|
||||
if (IsTaskSink()) {
|
||||
ResetNew();
|
||||
ReorderIndependentOrders(graph_ptr);
|
||||
AssignAllNodesStream(graph_ptr);
|
||||
FindAllReduceParallel(graph_ptr);
|
||||
InsertActiveNew(graph_ptr);
|
||||
|
@ -748,6 +749,67 @@ void AscendStreamAssign::GetWaitStreams(vector<uint32_t> *wait_active_stream_lis
|
|||
}
|
||||
|
||||
// Returns the sum of the common-stream and independent-stream counters.
uint32_t AscendStreamAssign::GetTotalStreamNum() const {
  return total_common_stream_num_ + total_independ_stream_num_;
}
|
||||
void AscendStreamAssign::ReorderIndependentOrders(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(graph_ptr);
|
||||
CNodePtr cur_cnode_ptr = nullptr;
|
||||
std::vector<CNodePtr> exe_orders;
|
||||
std::vector<CNodePtr> independents;
|
||||
std::vector<CNodePtr> others;
|
||||
|
||||
auto cnode_ptr_list = graph_ptr->execution_order();
|
||||
MS_LOG(INFO) << "before reorder, graph orders size:" << cnode_ptr_list.size();
|
||||
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
|
||||
cur_cnode_ptr = cnode_ptr_list[i];
|
||||
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
||||
if (IsIndependentNode(cur_cnode_ptr)) {
|
||||
independents.emplace_back(cur_cnode_ptr);
|
||||
} else {
|
||||
others.emplace_back(cur_cnode_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
if (others.empty()) {
|
||||
std::copy(independents.begin(), independents.end(), std::back_inserter(exe_orders));
|
||||
graph_ptr->set_execution_order(exe_orders);
|
||||
return;
|
||||
}
|
||||
|
||||
if (independents.empty()) {
|
||||
std::copy(others.begin(), others.end(), std::back_inserter(exe_orders));
|
||||
graph_ptr->set_execution_order(exe_orders);
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<CNodePtr> processed;
|
||||
for (size_t i = 0; i < others.size(); i++) {
|
||||
auto begin = others.begin() + i;
|
||||
auto end = begin + 1;
|
||||
bool flag = false;
|
||||
for (size_t j = 0; j < independents.size(); j++) {
|
||||
auto cur_independent = independents[j];
|
||||
auto it = std::find(processed.begin(), processed.end(), cur_independent);
|
||||
if (it != processed.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto res = FindTargetOp(begin, end, cur_independent);
|
||||
if (res != end) {
|
||||
flag = true;
|
||||
exe_orders.emplace_back(cur_independent);
|
||||
exe_orders.emplace_back(*begin);
|
||||
processed.emplace_back(cur_independent);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!flag) {
|
||||
exe_orders.emplace_back(*begin);
|
||||
}
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "after reorder, graph orders size:" << exe_orders.size();
|
||||
graph_ptr->set_execution_order(exe_orders);
|
||||
}
|
||||
|
||||
void AscendStreamAssign::PrintGraphExeOrders(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(graph_ptr);
|
||||
|
|
|
@ -97,6 +97,7 @@ class AscendStreamAssign {
|
|||
void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||
void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||
void GetNeedActiveStreams(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||
// Rewrites graph_ptr's execution order so that an independent node is placed directly before the first
// non-independent node FindTargetOp matches it with.
void ReorderIndependentOrders(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||
|
||||
uint32_t total_common_stream_num_{0};
|
||||
uint32_t total_independ_stream_num_{0};
|
||||
|
|
Loading…
Reference in New Issue