forked from mindspore-Ecosystem/mindspore
!1598 reorder DropoutGenMask nodes for stream parallel
Merge pull request !1598 from gukecai/reorder-genmask
This commit is contained in:
commit
fe6b279cee
|
@ -531,6 +531,7 @@ void AscendStreamAssign::GetNeedActiveStreams(const shared_ptr<session::KernelGr
|
||||||
void AscendStreamAssign::AssignStreamNew(const shared_ptr<session::KernelGraph> &graph_ptr) {
|
void AscendStreamAssign::AssignStreamNew(const shared_ptr<session::KernelGraph> &graph_ptr) {
|
||||||
if (IsTaskSink()) {
|
if (IsTaskSink()) {
|
||||||
ResetNew();
|
ResetNew();
|
||||||
|
ReorderIndependentOrders(graph_ptr);
|
||||||
AssignAllNodesStream(graph_ptr);
|
AssignAllNodesStream(graph_ptr);
|
||||||
FindAllReduceParallel(graph_ptr);
|
FindAllReduceParallel(graph_ptr);
|
||||||
InsertActiveNew(graph_ptr);
|
InsertActiveNew(graph_ptr);
|
||||||
|
@ -748,6 +749,67 @@ void AscendStreamAssign::GetWaitStreams(vector<uint32_t> *wait_active_stream_lis
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t AscendStreamAssign::GetTotalStreamNum() const { return total_common_stream_num_ + total_independ_stream_num_; }
|
uint32_t AscendStreamAssign::GetTotalStreamNum() const { return total_common_stream_num_ + total_independ_stream_num_; }
|
||||||
|
void AscendStreamAssign::ReorderIndependentOrders(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr) {
|
||||||
|
MS_EXCEPTION_IF_NULL(graph_ptr);
|
||||||
|
CNodePtr cur_cnode_ptr = nullptr;
|
||||||
|
std::vector<CNodePtr> exe_orders;
|
||||||
|
std::vector<CNodePtr> independents;
|
||||||
|
std::vector<CNodePtr> others;
|
||||||
|
|
||||||
|
auto cnode_ptr_list = graph_ptr->execution_order();
|
||||||
|
MS_LOG(INFO) << "before reorder, graph orders size:" << cnode_ptr_list.size();
|
||||||
|
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
|
||||||
|
cur_cnode_ptr = cnode_ptr_list[i];
|
||||||
|
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
||||||
|
if (IsIndependentNode(cur_cnode_ptr)) {
|
||||||
|
independents.emplace_back(cur_cnode_ptr);
|
||||||
|
} else {
|
||||||
|
others.emplace_back(cur_cnode_ptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (others.empty()) {
|
||||||
|
std::copy(independents.begin(), independents.end(), std::back_inserter(exe_orders));
|
||||||
|
graph_ptr->set_execution_order(exe_orders);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (independents.empty()) {
|
||||||
|
std::copy(others.begin(), others.end(), std::back_inserter(exe_orders));
|
||||||
|
graph_ptr->set_execution_order(exe_orders);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<CNodePtr> processed;
|
||||||
|
for (size_t i = 0; i < others.size(); i++) {
|
||||||
|
auto begin = others.begin() + i;
|
||||||
|
auto end = begin + 1;
|
||||||
|
bool flag = false;
|
||||||
|
for (size_t j = 0; j < independents.size(); j++) {
|
||||||
|
auto cur_independent = independents[j];
|
||||||
|
auto it = std::find(processed.begin(), processed.end(), cur_independent);
|
||||||
|
if (it != processed.end()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = FindTargetOp(begin, end, cur_independent);
|
||||||
|
if (res != end) {
|
||||||
|
flag = true;
|
||||||
|
exe_orders.emplace_back(cur_independent);
|
||||||
|
exe_orders.emplace_back(*begin);
|
||||||
|
processed.emplace_back(cur_independent);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!flag) {
|
||||||
|
exe_orders.emplace_back(*begin);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "after reorder, graph orders size:" << exe_orders.size();
|
||||||
|
graph_ptr->set_execution_order(exe_orders);
|
||||||
|
}
|
||||||
|
|
||||||
void AscendStreamAssign::PrintGraphExeOrders(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr) {
|
void AscendStreamAssign::PrintGraphExeOrders(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr) {
|
||||||
MS_EXCEPTION_IF_NULL(graph_ptr);
|
MS_EXCEPTION_IF_NULL(graph_ptr);
|
||||||
|
|
|
@ -97,6 +97,7 @@ class AscendStreamAssign {
|
||||||
void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||||
void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||||
void GetNeedActiveStreams(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
void GetNeedActiveStreams(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||||
|
void ReorderIndependentOrders(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||||
|
|
||||||
uint32_t total_common_stream_num_{0};
|
uint32_t total_common_stream_num_{0};
|
||||||
uint32_t total_independ_stream_num_{0};
|
uint32_t total_independ_stream_num_{0};
|
||||||
|
|
Loading…
Reference in New Issue