!1598 reorder DropoutGenMask nodes for stream parallel

Merge pull request !1598 from gukecai/reorder-genmask
This commit is contained in:
mindspore-ci-bot 2020-05-28 21:59:38 +08:00 committed by Gitee
commit fe6b279cee
2 changed files with 63 additions and 0 deletions

View File

@ -531,6 +531,7 @@ void AscendStreamAssign::GetNeedActiveStreams(const shared_ptr<session::KernelGr
void AscendStreamAssign::AssignStreamNew(const shared_ptr<session::KernelGraph> &graph_ptr) {
if (IsTaskSink()) {
ResetNew();
ReorderIndependentOrders(graph_ptr);
AssignAllNodesStream(graph_ptr);
FindAllReduceParallel(graph_ptr);
InsertActiveNew(graph_ptr);
@ -748,6 +749,67 @@ void AscendStreamAssign::GetWaitStreams(vector<uint32_t> *wait_active_stream_lis
}
uint32_t AscendStreamAssign::GetTotalStreamNum() const { return total_common_stream_num_ + total_independ_stream_num_; }
void AscendStreamAssign::ReorderIndependentOrders(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr) {
MS_EXCEPTION_IF_NULL(graph_ptr);
CNodePtr cur_cnode_ptr = nullptr;
std::vector<CNodePtr> exe_orders;
std::vector<CNodePtr> independents;
std::vector<CNodePtr> others;
auto cnode_ptr_list = graph_ptr->execution_order();
MS_LOG(INFO) << "before reorder, graph orders size:" << cnode_ptr_list.size();
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
cur_cnode_ptr = cnode_ptr_list[i];
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
if (IsIndependentNode(cur_cnode_ptr)) {
independents.emplace_back(cur_cnode_ptr);
} else {
others.emplace_back(cur_cnode_ptr);
}
}
if (others.empty()) {
std::copy(independents.begin(), independents.end(), std::back_inserter(exe_orders));
graph_ptr->set_execution_order(exe_orders);
return;
}
if (independents.empty()) {
std::copy(others.begin(), others.end(), std::back_inserter(exe_orders));
graph_ptr->set_execution_order(exe_orders);
return;
}
std::vector<CNodePtr> processed;
for (size_t i = 0; i < others.size(); i++) {
auto begin = others.begin() + i;
auto end = begin + 1;
bool flag = false;
for (size_t j = 0; j < independents.size(); j++) {
auto cur_independent = independents[j];
auto it = std::find(processed.begin(), processed.end(), cur_independent);
if (it != processed.end()) {
continue;
}
auto res = FindTargetOp(begin, end, cur_independent);
if (res != end) {
flag = true;
exe_orders.emplace_back(cur_independent);
exe_orders.emplace_back(*begin);
processed.emplace_back(cur_independent);
break;
}
}
if (!flag) {
exe_orders.emplace_back(*begin);
}
}
MS_LOG(INFO) << "after reorder, graph orders size:" << exe_orders.size();
graph_ptr->set_execution_order(exe_orders);
}
void AscendStreamAssign::PrintGraphExeOrders(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr) {
MS_EXCEPTION_IF_NULL(graph_ptr);

View File

@ -97,6 +97,7 @@ class AscendStreamAssign {
void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void GetNeedActiveStreams(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void ReorderIndependentOrders(const std::shared_ptr<session::KernelGraph> &graph_ptr);
uint32_t total_common_stream_num_{0};
uint32_t total_independ_stream_num_{0};