forked from mindspore-Ecosystem/mindspore
!4813 parallel control stream
Merge pull request !4813 from gukecai/parallel-ctrl
This commit is contained in:
commit
da51877546
|
@ -211,8 +211,11 @@ bool CommunicationOpFusion::DoFusion(const FuncGraphPtr &func_graph, const Commu
|
||||||
start_index = end_index + 1;
|
start_index = end_index + 1;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
|
||||||
|
auto graph_id = kernel_graph->graph_id();
|
||||||
AnfNodePtr new_communication_op =
|
AnfNodePtr new_communication_op =
|
||||||
CreateFusedCommunicationOp(func_graph, communication_op_info, start_index, end_index);
|
CreateFusedCommunicationOp(func_graph, communication_op_info, start_index, end_index);
|
||||||
|
AnfAlgo::SetGraphId(graph_id, new_communication_op.get());
|
||||||
// replace old communication op with new communication op
|
// replace old communication op with new communication op
|
||||||
for (auto idx = start_index; idx <= end_index; ++idx) {
|
for (auto idx = start_index; idx <= end_index; ++idx) {
|
||||||
std::vector<AnfNodePtr> tuple_getitem_input;
|
std::vector<AnfNodePtr> tuple_getitem_input;
|
||||||
|
|
|
@ -36,6 +36,7 @@ const uint32_t kCommonMaxTask = 350;
|
||||||
void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) {
|
void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) {
|
||||||
if (IsTaskSink()) {
|
if (IsTaskSink()) {
|
||||||
Reset();
|
Reset();
|
||||||
|
SetLoopSink();
|
||||||
ReorderIndependentOrders(graph_ptr);
|
ReorderIndependentOrders(graph_ptr);
|
||||||
AssignAllNodesStream(graph_ptr);
|
AssignAllNodesStream(graph_ptr);
|
||||||
UpdateAtomicAddrCleanStreamId(graph_ptr);
|
UpdateAtomicAddrCleanStreamId(graph_ptr);
|
||||||
|
@ -46,9 +47,9 @@ void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr)
|
||||||
InsertCtrlForIndependentParallel(graph_ptr);
|
InsertCtrlForIndependentParallel(graph_ptr);
|
||||||
|
|
||||||
GetNeedActiveStreams(graph_ptr);
|
GetNeedActiveStreams(graph_ptr);
|
||||||
graph_ptr->PrintGraphExecuteOrder();
|
|
||||||
CheckResourceAssign(graph_ptr);
|
CheckResourceAssign(graph_ptr);
|
||||||
MS_LOG(INFO) << "After finish stream assign";
|
MS_LOG(INFO) << "After finish stream assign";
|
||||||
|
graph_ptr->PrintGraphExecuteOrder();
|
||||||
|
|
||||||
FindStreamRelations(graph_ptr);
|
FindStreamRelations(graph_ptr);
|
||||||
PrintStreamRelations();
|
PrintStreamRelations();
|
||||||
|
@ -58,6 +59,14 @@ void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void AscendStreamAssign::SetLoopSink() {
|
||||||
|
if (KernelAdjust::NeedInsertSwitch()) {
|
||||||
|
loop_sink_ = true;
|
||||||
|
} else {
|
||||||
|
loop_sink_ = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// section 1
|
// section 1
|
||||||
void AscendStreamAssign::ReorderIndependentOrders(const NotNull<KernelGraphPtr> &graph_ptr) {
|
void AscendStreamAssign::ReorderIndependentOrders(const NotNull<KernelGraphPtr> &graph_ptr) {
|
||||||
std::vector<CNodePtr> exe_orders;
|
std::vector<CNodePtr> exe_orders;
|
||||||
|
@ -146,7 +155,7 @@ void AscendStreamAssign::AssignAllNodesStream(const NotNull<KernelGraphPtr> &gra
|
||||||
MS_LOG(INFO) << "Common start from 0, common stream nums:" << resource_manager.get_cur_stream_num();
|
MS_LOG(INFO) << "Common start from 0, common stream nums:" << resource_manager.get_cur_stream_num();
|
||||||
|
|
||||||
if (exit_hcom) {
|
if (exit_hcom) {
|
||||||
uint32_t first_hcom_stream_id = resource_manager.ApplyNewStream();
|
std::map<uint32_t, std::vector<CNodePtr>> graph_nodes_map;
|
||||||
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
|
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
|
||||||
CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
|
CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
|
||||||
// node has been assigned stream before
|
// node has been assigned stream before
|
||||||
|
@ -155,28 +164,63 @@ void AscendStreamAssign::AssignAllNodesStream(const NotNull<KernelGraphPtr> &gra
|
||||||
}
|
}
|
||||||
|
|
||||||
if (IsHcom(cur_cnode_ptr)) {
|
if (IsHcom(cur_cnode_ptr)) {
|
||||||
AssignHcomStreamId(cur_cnode_ptr);
|
auto hcom_graph_id = AnfAlgo::GetGraphId(cur_cnode_ptr.get());
|
||||||
|
auto it = graph_nodes_map.find(hcom_graph_id);
|
||||||
|
if (it == graph_nodes_map.end()) {
|
||||||
|
graph_nodes_map[hcom_graph_id] = {cur_cnode_ptr};
|
||||||
|
} else {
|
||||||
|
it->second.emplace_back(cur_cnode_ptr);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << "Hcom start from :" << first_hcom_stream_id << ", hcom stream nums:" << hcom_stream_map_.size();
|
MS_LOG(INFO) << "hcom diff graph id size:" << graph_nodes_map.size();
|
||||||
|
for (const auto &item : graph_nodes_map) {
|
||||||
|
bool new_graph = true;
|
||||||
|
auto graph_id = item.first;
|
||||||
|
hcom_graph_map_[graph_id] = {};
|
||||||
|
for (const auto &hcom_node_ptr : item.second) {
|
||||||
|
auto assigned_stream_id = AssignHcomStreamId(hcom_node_ptr, new_graph);
|
||||||
|
hcom_graph_map_[graph_id].emplace(assigned_stream_id);
|
||||||
|
new_graph = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "hcom stream nums : " << hcom_stream_map_.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (exit_independent) {
|
if (exit_independent) {
|
||||||
uint32_t first_independ = resource_manager.ApplyNewStream();
|
std::map<uint32_t, std::vector<CNodePtr>> graph_nodes_map;
|
||||||
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
|
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
|
||||||
CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
|
CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
|
||||||
if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) {
|
if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (AnfAlgo::IsIndependentNode(cur_cnode_ptr)) {
|
if (AnfAlgo::IsIndependentNode(cur_cnode_ptr)) {
|
||||||
AssignIndependentStreamId(cur_cnode_ptr);
|
auto independent_graph_id = AnfAlgo::GetGraphId(cur_cnode_ptr.get());
|
||||||
|
auto it = graph_nodes_map.find(independent_graph_id);
|
||||||
|
if (it == graph_nodes_map.end()) {
|
||||||
|
graph_nodes_map[independent_graph_id] = {cur_cnode_ptr};
|
||||||
|
} else {
|
||||||
|
it->second.emplace_back(cur_cnode_ptr);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << "Independ start from:" << first_independ << ", stream nums:" << independent_stream_map_.size();
|
|
||||||
|
MS_LOG(INFO) << "independent diff graph id size:" << graph_nodes_map.size();
|
||||||
|
for (const auto &item : graph_nodes_map) {
|
||||||
|
bool new_graph = true;
|
||||||
|
auto graph_id = item.first;
|
||||||
|
independent_graph_map_[graph_id] = {};
|
||||||
|
for (const auto &independent_node_ptr : item.second) {
|
||||||
|
auto assigned_stream_id = AssignIndependentStreamId(independent_node_ptr, new_graph);
|
||||||
|
independent_graph_map_[graph_id].emplace(assigned_stream_id);
|
||||||
|
new_graph = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "stream nums:" << independent_stream_map_.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(INFO) << "After stream assign, total stream nums:" << resource_manager.get_cur_stream_num();
|
MS_LOG(INFO) << "After stream assign, total stream nums:" << resource_manager.get_cur_stream_num();
|
||||||
}
|
} // namespace ascend
|
||||||
|
|
||||||
void AscendStreamAssign::AssignCommonStreamId(const CNodePtr &cur_cnode_ptr) {
|
void AscendStreamAssign::AssignCommonStreamId(const CNodePtr &cur_cnode_ptr) {
|
||||||
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
||||||
|
@ -205,10 +249,15 @@ void AscendStreamAssign::AssignCommonStreamId(const CNodePtr &cur_cnode_ptr) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void AscendStreamAssign::AssignHcomStreamId(const CNodePtr &cur_cnode_ptr) {
|
uint32_t AscendStreamAssign::AssignHcomStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph) {
|
||||||
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
||||||
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
|
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
|
||||||
uint32_t cur_hcom_stream_id = resource_manager.GetCurAllocStreamId();
|
uint32_t cur_hcom_stream_id;
|
||||||
|
if (new_graph) {
|
||||||
|
cur_hcom_stream_id = resource_manager.ApplyNewStream();
|
||||||
|
} else {
|
||||||
|
cur_hcom_stream_id = resource_manager.GetCurAllocStreamId();
|
||||||
|
}
|
||||||
auto it = hcom_stream_map_.find(cur_hcom_stream_id);
|
auto it = hcom_stream_map_.find(cur_hcom_stream_id);
|
||||||
if (it == hcom_stream_map_.end()) {
|
if (it == hcom_stream_map_.end()) {
|
||||||
AnfAlgo::SetStreamId(cur_hcom_stream_id, cur_cnode_ptr.get());
|
AnfAlgo::SetStreamId(cur_hcom_stream_id, cur_cnode_ptr.get());
|
||||||
|
@ -223,26 +272,34 @@ void AscendStreamAssign::AssignHcomStreamId(const CNodePtr &cur_cnode_ptr) {
|
||||||
hcom_stream_map_.insert(std::make_pair(cur_hcom_stream_id, 1));
|
hcom_stream_map_.insert(std::make_pair(cur_hcom_stream_id, 1));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return cur_hcom_stream_id;
|
||||||
}
|
}
|
||||||
|
|
||||||
void AscendStreamAssign::AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr) {
|
uint32_t AscendStreamAssign::AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph) {
|
||||||
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
||||||
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
|
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
|
||||||
uint32_t cur_independent_id = resource_manager.GetCurAllocStreamId();
|
uint32_t cur_independent_stream_id;
|
||||||
auto it = independent_stream_map_.find(cur_independent_id);
|
if (new_graph) {
|
||||||
|
cur_independent_stream_id = resource_manager.ApplyNewStream();
|
||||||
|
} else {
|
||||||
|
cur_independent_stream_id = resource_manager.GetCurAllocStreamId();
|
||||||
|
}
|
||||||
|
auto it = independent_stream_map_.find(cur_independent_stream_id);
|
||||||
if (it == independent_stream_map_.end()) {
|
if (it == independent_stream_map_.end()) {
|
||||||
AnfAlgo::SetStreamId(cur_independent_id, cur_cnode_ptr.get());
|
AnfAlgo::SetStreamId(cur_independent_stream_id, cur_cnode_ptr.get());
|
||||||
independent_stream_map_.insert(std::make_pair(cur_independent_id, 1));
|
independent_stream_map_.insert(std::make_pair(cur_independent_stream_id, 1));
|
||||||
} else {
|
} else {
|
||||||
if (it->second < kCommonMaxTask) {
|
if (it->second < kCommonMaxTask) {
|
||||||
AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get());
|
AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get());
|
||||||
it->second++;
|
it->second++;
|
||||||
} else {
|
} else {
|
||||||
cur_independent_id = resource_manager.ApplyNewStream();
|
cur_independent_stream_id = resource_manager.ApplyNewStream();
|
||||||
AnfAlgo::SetStreamId(cur_independent_id, cur_cnode_ptr.get());
|
AnfAlgo::SetStreamId(cur_independent_stream_id, cur_cnode_ptr.get());
|
||||||
independent_stream_map_.insert(std::make_pair(cur_independent_id, 1));
|
independent_stream_map_.insert(std::make_pair(cur_independent_stream_id, 1));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return cur_independent_stream_id;
|
||||||
}
|
}
|
||||||
|
|
||||||
// section 3:
|
// section 3:
|
||||||
|
@ -262,6 +319,182 @@ void AscendStreamAssign::UpdateAtomicAddrCleanStreamId(const NotNull<KernelGraph
|
||||||
|
|
||||||
// section 4
|
// section 4
|
||||||
void AscendStreamAssign::InsertStreamActive(const NotNull<KernelGraphPtr> &graph_ptr) {
|
void AscendStreamAssign::InsertStreamActive(const NotNull<KernelGraphPtr> &graph_ptr) {
|
||||||
|
InsertStreamActiveForCommon(graph_ptr);
|
||||||
|
InsertStreamActiveForIndependent(graph_ptr);
|
||||||
|
InsertStreamActiveForParallel(graph_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
void AscendStreamAssign::InsertStreamActiveForParallel(const NotNull<KernelGraphPtr> &graph_ptr) {
|
||||||
|
if (hcom_graph_map_.empty() && independent_graph_map_.empty()) {
|
||||||
|
MS_LOG(INFO) << "Hcom and independent is empty";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto root_graph_id = graph_ptr->graph_id();
|
||||||
|
if (root_graph_id == kInvalidGraphId) {
|
||||||
|
MS_LOG(INFO) << "Root graph id is invalid";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_LOG(DEBUG) << "Hcom grpah map size:" << hcom_graph_map_.size();
|
||||||
|
std::map<uint32_t, std::set<uint32_t>> other_graph;
|
||||||
|
for (const auto &item : hcom_graph_map_) {
|
||||||
|
MS_LOG(INFO) << "Graph id:" << item.first;
|
||||||
|
if (item.first == root_graph_id) {
|
||||||
|
if (loop_sink_) {
|
||||||
|
ActiveRootGraphHcom(graph_ptr, item.second);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
other_graph[item.first] = item.second;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "Independent graph map size:" << independent_graph_map_.size();
|
||||||
|
for (const auto &item : independent_graph_map_) {
|
||||||
|
MS_LOG(DEBUG) << "Graph id:" << item.first;
|
||||||
|
if (item.first == root_graph_id) {
|
||||||
|
if (loop_sink_) {
|
||||||
|
ActiveRootGraphIndependent(graph_ptr, item.second);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto it = other_graph.find(item.first);
|
||||||
|
if (it == other_graph.end()) {
|
||||||
|
other_graph[item.first] = item.second;
|
||||||
|
} else {
|
||||||
|
for (const auto &stream : item.second) {
|
||||||
|
it->second.emplace(stream);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ActiveOtherGraphParallel(graph_ptr, other_graph);
|
||||||
|
}
|
||||||
|
|
||||||
|
void AscendStreamAssign::ActiveOtherGraphParallel(const NotNull<KernelGraphPtr> &graph_ptr,
|
||||||
|
std::map<uint32_t, std::set<uint32_t>> other_graph) {
|
||||||
|
MS_LOG(INFO) << "Other graph size:" << other_graph.size();
|
||||||
|
if (other_graph.empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto root_graph_id = graph_ptr->graph_id();
|
||||||
|
|
||||||
|
std::vector<CNodePtr> update_stream_list;
|
||||||
|
auto exe_order = graph_ptr->execution_order();
|
||||||
|
for (size_t i = 0; i < exe_order.size(); i++) {
|
||||||
|
auto cur_cnode_ptr = exe_order[i];
|
||||||
|
auto cur_graph_id = AnfAlgo::GetGraphId(cur_cnode_ptr.get());
|
||||||
|
if (cur_graph_id == root_graph_id) {
|
||||||
|
update_stream_list.emplace_back(cur_cnode_ptr);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto it = other_graph.find(cur_graph_id);
|
||||||
|
if (it == other_graph.end()) {
|
||||||
|
update_stream_list.emplace_back(cur_cnode_ptr);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
|
||||||
|
CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr);
|
||||||
|
// 1.set stream id
|
||||||
|
AnfAlgo::SetStreamId(cur_stream_id, active_ptr.get());
|
||||||
|
// 2.set active stream ids
|
||||||
|
std::vector<uint32_t> active_index_list;
|
||||||
|
std::copy(it->second.begin(), it->second.end(), std::back_inserter(active_index_list));
|
||||||
|
AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_index_list), active_ptr);
|
||||||
|
|
||||||
|
// find position for insert streamactive
|
||||||
|
if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kLabelSetOpName) {
|
||||||
|
update_stream_list.emplace_back(cur_cnode_ptr);
|
||||||
|
update_stream_list.emplace_back(active_ptr);
|
||||||
|
} else {
|
||||||
|
update_stream_list.emplace_back(active_ptr);
|
||||||
|
update_stream_list.emplace_back(cur_cnode_ptr);
|
||||||
|
}
|
||||||
|
other_graph.erase(it);
|
||||||
|
}
|
||||||
|
graph_ptr->set_execution_order(update_stream_list);
|
||||||
|
}
|
||||||
|
|
||||||
|
void AscendStreamAssign::ActiveRootGraphHcom(const NotNull<KernelGraphPtr> &graph_ptr,
|
||||||
|
const std::set<uint32_t> &hcom_streams) {
|
||||||
|
MS_LOG(INFO) << "Active root graph hcom start";
|
||||||
|
std::vector<CNodePtr> update_cnode_list;
|
||||||
|
auto exe_orders = graph_ptr->execution_order();
|
||||||
|
for (size_t i = 0; i < exe_orders.size(); i++) {
|
||||||
|
CNodePtr cur_cnode_ptr = exe_orders[i];
|
||||||
|
if (AnfAlgo::GetCNodeName(cur_cnode_ptr) != kStreamSwitchOpName) {
|
||||||
|
update_cnode_list.emplace_back(cur_cnode_ptr);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!AnfAlgo::HasNodeAttr(kAttrStreamSwitchKind, cur_cnode_ptr)) {
|
||||||
|
update_cnode_list.emplace_back(cur_cnode_ptr);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto kind = AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrStreamSwitchKind);
|
||||||
|
if (kind != kFpBpStreamSwitch) {
|
||||||
|
update_cnode_list.emplace_back(cur_cnode_ptr);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto true_stream_id = AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrTrueBranchStream);
|
||||||
|
MS_LOG(INFO) << "FpBpStreamswtich stream id:" << AnfAlgo::GetStreamId(cur_cnode_ptr)
|
||||||
|
<< "; true branch stream id:" << true_stream_id;
|
||||||
|
CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr);
|
||||||
|
AnfAlgo::SetStreamId(true_stream_id, active_ptr.get());
|
||||||
|
vector<uint32_t> active_ids;
|
||||||
|
// active hcom stream
|
||||||
|
std::copy(hcom_streams.begin(), hcom_streams.end(), std::back_inserter(active_ids));
|
||||||
|
AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_ids), active_ptr);
|
||||||
|
update_cnode_list.emplace_back(cur_cnode_ptr);
|
||||||
|
update_cnode_list.emplace_back(active_ptr);
|
||||||
|
std::copy(exe_orders.begin() + i + 1, exe_orders.end(), std::back_inserter(update_cnode_list));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
hcom_stream_activated_ = true;
|
||||||
|
graph_ptr->set_execution_order(update_cnode_list);
|
||||||
|
}
|
||||||
|
|
||||||
|
void AscendStreamAssign::ActiveRootGraphIndependent(const NotNull<KernelGraphPtr> &graph_ptr,
|
||||||
|
std::set<uint32_t> independent_streams) {
|
||||||
|
MS_LOG(DEBUG) << "Start active root graph independent";
|
||||||
|
std::vector<CNodePtr> update_cnode_list;
|
||||||
|
auto exe_orders = graph_ptr->execution_order();
|
||||||
|
for (size_t i = 0; i < exe_orders.size(); i++) {
|
||||||
|
CNodePtr cur_cnode_ptr = exe_orders[i];
|
||||||
|
if (AnfAlgo::GetCNodeName(cur_cnode_ptr) != kStreamSwitchOpName) {
|
||||||
|
update_cnode_list.emplace_back(cur_cnode_ptr);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!AnfAlgo::HasNodeAttr(kAttrStreamSwitchKind, cur_cnode_ptr)) {
|
||||||
|
update_cnode_list.emplace_back(cur_cnode_ptr);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto kind = AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrStreamSwitchKind);
|
||||||
|
if (kind != kIndependentStreamSwitch) {
|
||||||
|
update_cnode_list.emplace_back(cur_cnode_ptr);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// first independetn stream id is minimum and order by std map;
|
||||||
|
auto first_independent_stream = *(independent_streams.begin());
|
||||||
|
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(first_independent_stream), cur_cnode_ptr);
|
||||||
|
update_cnode_list.emplace_back(cur_cnode_ptr);
|
||||||
|
std::copy(exe_orders.begin() + i + 1, exe_orders.end(), std::back_inserter(update_cnode_list));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
independent_stream_activated_ = true;
|
||||||
|
graph_ptr->set_execution_order(update_cnode_list);
|
||||||
|
}
|
||||||
|
|
||||||
|
void AscendStreamAssign::InsertStreamActiveForCommon(const NotNull<KernelGraphPtr> &graph_ptr) {
|
||||||
MS_LOG(INFO) << "Start";
|
MS_LOG(INFO) << "Start";
|
||||||
GetProcessedStream(graph_ptr);
|
GetProcessedStream(graph_ptr);
|
||||||
std::vector<CNodePtr> update_cnode_list;
|
std::vector<CNodePtr> update_cnode_list;
|
||||||
|
@ -298,7 +531,8 @@ void AscendStreamAssign::InsertStreamActive(const NotNull<KernelGraphPtr> &graph
|
||||||
|
|
||||||
if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) {
|
if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) {
|
||||||
MS_LOG(INFO) << "Insert StreamActive op after FP StreamSwitch for stream parallel";
|
MS_LOG(INFO) << "Insert StreamActive op after FP StreamSwitch for stream parallel";
|
||||||
UpdateStreamSwitch(graph_ptr, cur_cnode_ptr, &update_cnode_list);
|
// UpdateStreamSwitch(graph_ptr, cur_cnode_ptr, &update_cnode_list);
|
||||||
|
update_cnode_list.emplace_back(cur_cnode_ptr);
|
||||||
} else {
|
} else {
|
||||||
update_cnode_list.emplace_back(cur_cnode_ptr);
|
update_cnode_list.emplace_back(cur_cnode_ptr);
|
||||||
}
|
}
|
||||||
|
@ -308,6 +542,70 @@ void AscendStreamAssign::InsertStreamActive(const NotNull<KernelGraphPtr> &graph
|
||||||
pre_cnode_ptr = cur_cnode_ptr;
|
pre_cnode_ptr = cur_cnode_ptr;
|
||||||
}
|
}
|
||||||
graph_ptr->set_execution_order(update_cnode_list);
|
graph_ptr->set_execution_order(update_cnode_list);
|
||||||
|
}
|
||||||
|
|
||||||
|
void AscendStreamAssign::InsertStreamActiveForIndependent(const NotNull<KernelGraphPtr> &graph_ptr) {
|
||||||
|
auto root_graph_id = graph_ptr->graph_id();
|
||||||
|
if (root_graph_id == kInvalidGraphId) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
std::set<uint32_t> independent_streams;
|
||||||
|
for (const auto &item : independent_graph_map_) {
|
||||||
|
if (item.first == root_graph_id) {
|
||||||
|
independent_streams = item.second;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (independent_streams.size() <= 1) {
|
||||||
|
MS_LOG(INFO) << "Root graph independent stream size is not more than one, no need insert active";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
std::vector<CNodePtr> update_cnode_list;
|
||||||
|
auto exe_orders = graph_ptr->execution_order();
|
||||||
|
|
||||||
|
// first independent is been actived, active other independent stream
|
||||||
|
std::vector<uint32_t> streams;
|
||||||
|
std::copy(independent_streams.begin(), independent_streams.end(), std::back_inserter(streams));
|
||||||
|
std::sort(streams.begin(), streams.end());
|
||||||
|
uint32_t node_num = 0;
|
||||||
|
uint32_t cur_stream_id = kInvalidStreamId;
|
||||||
|
for (size_t i = 0; i < exe_orders.size(); i++) {
|
||||||
|
auto cur_cnode_ptr = exe_orders[i];
|
||||||
|
update_cnode_list.emplace_back(cur_cnode_ptr);
|
||||||
|
bool flag = AnfAlgo::IsIndependentNode(cur_cnode_ptr);
|
||||||
|
if (!flag) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto graph_id = AnfAlgo::GetGraphId(cur_cnode_ptr.get());
|
||||||
|
if (graph_id != root_graph_id) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
node_num++;
|
||||||
|
cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
|
||||||
|
auto it = std::find(streams.begin(), streams.end(), cur_stream_id);
|
||||||
|
if (it == streams.end()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Can't find independent stream id:" << cur_stream_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (it == streams.end() - 1) {
|
||||||
|
std::copy(exe_orders.begin() + i + 1, exe_orders.end(), std::back_inserter(update_cnode_list));
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
if (node_num == kCommonMaxTask) {
|
||||||
|
CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr);
|
||||||
|
// 1.set stream id
|
||||||
|
AnfAlgo::SetStreamId(cur_stream_id, active_ptr.get());
|
||||||
|
// 2.set active stream ids
|
||||||
|
std::vector<uint32_t> active_index_list{*(it + 1)};
|
||||||
|
AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_index_list), active_ptr);
|
||||||
|
update_cnode_list.emplace_back(active_ptr);
|
||||||
|
node_num = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
graph_ptr->set_execution_order(update_cnode_list);
|
||||||
MS_LOG(INFO) << "End";
|
MS_LOG(INFO) << "End";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -373,7 +671,7 @@ void AscendStreamAssign::UpdateStreamSwitch(const NotNull<KernelGraphPtr> &graph
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(ERROR) << "independent stream switch exit, but independent stream is empty";
|
MS_LOG(ERROR) << "Independent stream switch exit, but independent stream is empty";
|
||||||
}
|
}
|
||||||
|
|
||||||
// update processed stream
|
// update processed stream
|
||||||
|
@ -472,6 +770,77 @@ void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull<KernelGraphPt
|
||||||
MS_LOG(INFO) << "After common depend hcom, total event nums:" << resource_manager.get_cur_event_num();
|
MS_LOG(INFO) << "After common depend hcom, total event nums:" << resource_manager.get_cur_event_num();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
CNodePtr AscendStreamAssign::GetLastInputCnode(const NotNull<KernelGraphPtr> &graph_ptr,
|
||||||
|
const CNodePtr &cur_cnode_ptr) {
|
||||||
|
auto cnode_ptr_list = graph_ptr->execution_order();
|
||||||
|
auto &inputs = cur_cnode_ptr->inputs();
|
||||||
|
auto it_pos = cnode_ptr_list.begin();
|
||||||
|
for (size_t i = 1; i < inputs.size(); i++) {
|
||||||
|
if (inputs[i]->isa<CNode>()) {
|
||||||
|
auto cnode = inputs[i]->cast<CNodePtr>();
|
||||||
|
while (opt::IsNopNode(cnode)) {
|
||||||
|
cnode = cnode->inputs()[1]->cast<CNodePtr>();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto it = std::find(it_pos, cnode_ptr_list.end(), cnode);
|
||||||
|
if (it != cnode_ptr_list.end()) {
|
||||||
|
it_pos = it;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (it_pos == cnode_ptr_list.begin() && *it_pos != inputs[1]) {
|
||||||
|
MS_LOG(EXCEPTION) << "The input of node:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "was not found";
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "The las input of node:" << cur_cnode_ptr->DebugString() << " is:" << (*it_pos)->fullname_with_scope()
|
||||||
|
<< "; name:" << (*it_pos)->DebugString();
|
||||||
|
return *it_pos;
|
||||||
|
}
|
||||||
|
|
||||||
|
// after memory reuse is correct, use this function
|
||||||
|
void AscendStreamAssign::InsertEventHcomDependCommonBak(const NotNull<KernelGraphPtr> &graph_ptr) {
|
||||||
|
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
|
||||||
|
auto cnode_ptr_list = graph_ptr->execution_order();
|
||||||
|
vector<CNodePtr> cnodes;
|
||||||
|
CNodePtr cur_cnode_ptr = nullptr;
|
||||||
|
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
|
||||||
|
cur_cnode_ptr = cnode_ptr_list[i];
|
||||||
|
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
||||||
|
if (i == 0) {
|
||||||
|
cnodes.emplace_back(cur_cnode_ptr);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!IsHcom(cur_cnode_ptr)) {
|
||||||
|
cnodes.emplace_back(cur_cnode_ptr);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// get the input which located in the lastr exe orders
|
||||||
|
auto last_input_cnode = GetLastInputCnode(graph_ptr, cur_cnode_ptr);
|
||||||
|
auto it = std::find(cnodes.begin(), cnodes.end(), last_input_cnode);
|
||||||
|
if (it == cnodes.end()) {
|
||||||
|
MS_LOG(ERROR) << "hcom:" << AnfAlgo::GetCNodeName(cur_cnode_ptr)
|
||||||
|
<< "get last input:" << AnfAlgo::GetCNodeName(last_input_cnode) << "; but last input not in cnodes";
|
||||||
|
} else {
|
||||||
|
uint32_t cur_event_id = resource_manager.ApplyNewEvent();
|
||||||
|
auto last_stream_id = AnfAlgo::GetStreamId(last_input_cnode);
|
||||||
|
auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, last_stream_id);
|
||||||
|
cnodes.insert(it + 1, send);
|
||||||
|
uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
|
||||||
|
auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_stream_id);
|
||||||
|
cnodes.emplace_back(recv);
|
||||||
|
cnodes.emplace_back(cur_cnode_ptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
graph_ptr->set_execution_order(cnodes);
|
||||||
|
MS_LOG(INFO) << "After hcom depend common, total event nums:" << resource_manager.get_cur_event_num();
|
||||||
|
}
|
||||||
|
|
||||||
void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull<KernelGraphPtr> &graph_ptr) {
|
void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull<KernelGraphPtr> &graph_ptr) {
|
||||||
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
|
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
|
||||||
auto cnode_ptr_list = graph_ptr->execution_order();
|
auto cnode_ptr_list = graph_ptr->execution_order();
|
||||||
|
@ -641,7 +1010,7 @@ void AscendStreamAssign::InsertEventForIndependentParallel(const NotNull<KernelG
|
||||||
while (it != cnodes.end()) {
|
while (it != cnodes.end()) {
|
||||||
MS_EXCEPTION_IF_NULL(*it);
|
MS_EXCEPTION_IF_NULL(*it);
|
||||||
if (AnfAlgo::IsIndependentNode(*it)) {
|
if (AnfAlgo::IsIndependentNode(*it)) {
|
||||||
MS_LOG(INFO) << "Deal independent op[" << (*it)->DebugString() << "]";
|
MS_LOG(DEBUG) << "Deal independent op[" << (*it)->DebugString() << "]";
|
||||||
CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it));
|
CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it));
|
||||||
it = cnodes.insert(it + 1, send_cnode_ptr);
|
it = cnodes.insert(it + 1, send_cnode_ptr);
|
||||||
|
|
||||||
|
@ -690,7 +1059,7 @@ void AscendStreamAssign::GetIndependentMaxTarget(const NotNull<KernelGraphPtr> &
|
||||||
for (size_t k = 1; k < new_inputs.size(); k++) {
|
for (size_t k = 1; k < new_inputs.size(); k++) {
|
||||||
auto new_real_input = AnfAlgo::VisitKernel(new_inputs[k], 0);
|
auto new_real_input = AnfAlgo::VisitKernel(new_inputs[k], 0);
|
||||||
if (key == new_real_input.first.get()) {
|
if (key == new_real_input.first.get()) {
|
||||||
MS_LOG(INFO) << "Nop node find max target op:" << AnfAlgo::GetCNodeName(cur_node);
|
MS_LOG(DEBUG) << "Nop node find max target op:" << AnfAlgo::GetCNodeName(cur_node);
|
||||||
independent_targets_.emplace(target_node.get());
|
independent_targets_.emplace(target_node.get());
|
||||||
flag = true;
|
flag = true;
|
||||||
break;
|
break;
|
||||||
|
@ -699,7 +1068,7 @@ void AscendStreamAssign::GetIndependentMaxTarget(const NotNull<KernelGraphPtr> &
|
||||||
} else {
|
} else {
|
||||||
auto real_input = AnfAlgo::VisitKernel(input, 0);
|
auto real_input = AnfAlgo::VisitKernel(input, 0);
|
||||||
if (key == real_input.first.get()) {
|
if (key == real_input.first.get()) {
|
||||||
MS_LOG(INFO) << "Find max target op:" << AnfAlgo::GetCNodeName(cur_node);
|
MS_LOG(DEBUG) << "Find max target op:" << AnfAlgo::GetCNodeName(cur_node);
|
||||||
independent_targets_.emplace(target_node.get());
|
independent_targets_.emplace(target_node.get());
|
||||||
flag = true;
|
flag = true;
|
||||||
}
|
}
|
||||||
|
@ -772,7 +1141,7 @@ void AscendStreamAssign::InsertCtrlForIndependentParallel(const NotNull<KernelGr
|
||||||
auto max_index = GetMaxIndexTarget(graph_ptr);
|
auto max_index = GetMaxIndexTarget(graph_ptr);
|
||||||
auto &exe_orders = graph_ptr->execution_order();
|
auto &exe_orders = graph_ptr->execution_order();
|
||||||
if (max_index >= exe_orders.size()) {
|
if (max_index >= exe_orders.size()) {
|
||||||
MS_LOG(EXCEPTION) << "max target index:" << max_index << " is greater than graph orders size:" << exe_orders.size();
|
MS_LOG(EXCEPTION) << "Max target index:" << max_index << " is greater than graph orders size:" << exe_orders.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto max_node_stream = AnfAlgo::GetStreamId(exe_orders[max_index]);
|
auto max_node_stream = AnfAlgo::GetStreamId(exe_orders[max_index]);
|
||||||
|
@ -813,16 +1182,19 @@ void AscendStreamAssign::GetNeedActiveStreams(const NotNull<KernelGraphPtr> &gra
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2)independent stream:if has not been activate, push to need active vector
|
// 2)independent stream:if has not been activate, push to need active vector
|
||||||
|
auto root_graph_id = graph_ptr->graph_id();
|
||||||
if (!independent_stream_activated_) {
|
if (!independent_stream_activated_) {
|
||||||
for (auto &item : independent_stream_map_) {
|
auto it = independent_graph_map_.find(root_graph_id);
|
||||||
need_first_active_streams_.emplace_back(item.first);
|
if (it != independent_graph_map_.end()) {
|
||||||
|
need_first_active_streams_.push_back(*(it->second.begin()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3)hcom stream:if has not been activate, push to need active vector
|
// 3)hcom stream:if has not been activate, push to need active vector
|
||||||
if (!hcom_stream_activated_) {
|
if (!hcom_stream_activated_) {
|
||||||
for (auto &item : hcom_stream_map_) {
|
auto it = hcom_graph_map_.find(root_graph_id);
|
||||||
need_first_active_streams_.emplace_back(item.first);
|
if (it != hcom_graph_map_.end()) {
|
||||||
|
std::copy(it->second.begin(), it->second.end(), std::back_inserter(need_first_active_streams_));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -831,6 +1203,10 @@ void AscendStreamAssign::GetNeedActiveStreams(const NotNull<KernelGraphPtr> &gra
|
||||||
if (it == need_first_active_streams_.end()) {
|
if (it == need_first_active_streams_.end()) {
|
||||||
need_first_active_streams_.emplace_back(0);
|
need_first_active_streams_.emplace_back(0);
|
||||||
}
|
}
|
||||||
|
MS_LOG(INFO) << "Finally, need active first stream include:";
|
||||||
|
for (const auto &item : need_first_active_streams_) {
|
||||||
|
MS_LOG(INFO) << "stream id:" << item;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// section8
|
// section8
|
||||||
|
@ -977,14 +1353,14 @@ vector<CNodePtr>::iterator AscendStreamAssign::FindTargetOp(vector<CNodePtr>::it
|
||||||
for (size_t j = 1; j < new_inputs.size(); j++) {
|
for (size_t j = 1; j < new_inputs.size(); j++) {
|
||||||
auto new_real_input = AnfAlgo::VisitKernel(new_inputs[j], 0);
|
auto new_real_input = AnfAlgo::VisitKernel(new_inputs[j], 0);
|
||||||
if (node == new_real_input.first) {
|
if (node == new_real_input.first) {
|
||||||
MS_LOG(INFO) << "Nop node find target op[" << (*begin)->DebugString() << "]";
|
MS_LOG(DEBUG) << "Nop node find target op[" << (*begin)->DebugString() << "]";
|
||||||
return begin;
|
return begin;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
auto real_input = AnfAlgo::VisitKernel(input, 0);
|
auto real_input = AnfAlgo::VisitKernel(input, 0);
|
||||||
if (node == real_input.first) {
|
if (node == real_input.first) {
|
||||||
MS_LOG(INFO) << "Find target op[" << (*begin)->DebugString() << "]";
|
MS_LOG(DEBUG) << "Find target op[" << (*begin)->DebugString() << "]";
|
||||||
return begin;
|
return begin;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1040,6 +1416,7 @@ void AscendStreamAssign::GetHcomStreams(std::vector<uint32_t> *streams) {
|
||||||
void AscendStreamAssign::Reset() {
|
void AscendStreamAssign::Reset() {
|
||||||
independent_stream_activated_ = false;
|
independent_stream_activated_ = false;
|
||||||
hcom_stream_activated_ = false;
|
hcom_stream_activated_ = false;
|
||||||
|
loop_sink_ = false;
|
||||||
independent_stream_map_.clear();
|
independent_stream_map_.clear();
|
||||||
hcom_stream_map_.clear();
|
hcom_stream_map_.clear();
|
||||||
common_stream_map_.clear();
|
common_stream_map_.clear();
|
||||||
|
@ -1049,6 +1426,9 @@ void AscendStreamAssign::Reset() {
|
||||||
stream_relations_.clear();
|
stream_relations_.clear();
|
||||||
event_map_.clear();
|
event_map_.clear();
|
||||||
independent_targets_.clear();
|
independent_targets_.clear();
|
||||||
|
independent_graph_map_.clear();
|
||||||
|
hcom_graph_map_.clear();
|
||||||
|
middle_active_streams_.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
// section 10
|
// section 10
|
||||||
|
@ -1101,7 +1481,12 @@ void AscendStreamAssign::DFS(uint32_t start, std::vector<uint32_t> *group) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void AscendStreamAssign::GetStreamRelations() {
|
void AscendStreamAssign::GetStreamRelations() {
|
||||||
for (const auto &start : need_first_active_streams_) {
|
auto starts = middle_active_streams_;
|
||||||
|
for (const auto &stream : need_first_active_streams_) {
|
||||||
|
starts.emplace(stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto &start : starts) {
|
||||||
vector<uint32_t> group{start};
|
vector<uint32_t> group{start};
|
||||||
DFS(start, &group);
|
DFS(start, &group);
|
||||||
}
|
}
|
||||||
|
@ -1188,7 +1573,8 @@ void AscendStreamAssign::GetStreamActiveStreamRelation(const NotNull<KernelGraph
|
||||||
if (stream <= cur_stream_id) {
|
if (stream <= cur_stream_id) {
|
||||||
MS_LOG(INFO) << "MIDDLE StreamActive active stream is less than self stream, no need deal";
|
MS_LOG(INFO) << "MIDDLE StreamActive active stream is less than self stream, no need deal";
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(ERROR) << "MIDDLE StreamActive active stream is greater than self stream, should not be exit now";
|
MS_LOG(INFO) << "MIDDLE StreamActive :" << cur_stream_id << ", active target stream:" << stream;
|
||||||
|
middle_active_streams_.emplace(stream);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -123,11 +123,18 @@ class AscendStreamAssign {
|
||||||
void CheckEventAssign(const NotNull<KernelGraphPtr> &graph_ptr);
|
void CheckEventAssign(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||||
void AssignAllNodesStream(const NotNull<KernelGraphPtr> &graph_ptr);
|
void AssignAllNodesStream(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||||
void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr);
|
void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr);
|
||||||
void AssignHcomStreamId(const CNodePtr &cur_cnode_ptr);
|
uint32_t AssignHcomStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph);
|
||||||
void AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr);
|
uint32_t AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph);
|
||||||
void UpdateAtomicAddrCleanStreamId(const NotNull<KernelGraphPtr> &graph_ptr);
|
void UpdateAtomicAddrCleanStreamId(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||||
void FindHcomParallelStreams(const NotNull<KernelGraphPtr> &graph_ptr);
|
void FindHcomParallelStreams(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||||
void InsertStreamActive(const NotNull<KernelGraphPtr> &graph_ptr);
|
void InsertStreamActive(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||||
|
void InsertStreamActiveForCommon(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||||
|
void InsertStreamActiveForIndependent(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||||
|
void InsertStreamActiveForParallel(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||||
|
void ActiveRootGraphHcom(const NotNull<KernelGraphPtr> &graph_ptr, const std::set<uint32_t> &hcom_streams);
|
||||||
|
void ActiveRootGraphIndependent(const NotNull<KernelGraphPtr> &graph_ptr, std::set<uint32_t> independent_streams);
|
||||||
|
void ActiveOtherGraphParallel(const NotNull<KernelGraphPtr> &graph_ptr,
|
||||||
|
std::map<uint32_t, std::set<uint32_t>> other_graph);
|
||||||
void UpdateStreamSwitch(const NotNull<KernelGraphPtr> &graph_ptr, const CNodePtr &switch_ptr,
|
void UpdateStreamSwitch(const NotNull<KernelGraphPtr> &graph_ptr, const CNodePtr &switch_ptr,
|
||||||
vector<CNodePtr> *orders);
|
vector<CNodePtr> *orders);
|
||||||
void InsertEventForIndependentParallel(const NotNull<KernelGraphPtr> &graph_ptr);
|
void InsertEventForIndependentParallel(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||||
|
@ -135,9 +142,11 @@ class AscendStreamAssign {
|
||||||
void InsertEventForHcomParallel(const NotNull<KernelGraphPtr> &graph_ptr);
|
void InsertEventForHcomParallel(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||||
void InsertEventCommonDependHcom(const NotNull<KernelGraphPtr> &graph_ptr);
|
void InsertEventCommonDependHcom(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||||
void InsertEventHcomDependCommon(const NotNull<KernelGraphPtr> &graph_ptr);
|
void InsertEventHcomDependCommon(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||||
|
void InsertEventHcomDependCommonBak(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||||
void InsertEventHcomDependHcom(const NotNull<KernelGraphPtr> &graph_ptr);
|
void InsertEventHcomDependHcom(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||||
void InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &graph_ptr, const map<uint32_t, vector<size_t>> &hcom_index,
|
void InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &graph_ptr, const map<uint32_t, vector<size_t>> &hcom_index,
|
||||||
uint32_t first_hcom_stream, uint32_t last_hcom_stream);
|
uint32_t first_hcom_stream, uint32_t last_hcom_stream);
|
||||||
|
CNodePtr GetLastInputCnode(const NotNull<KernelGraphPtr> &graph_ptr, const CNodePtr &cur_cnode_ptr);
|
||||||
bool IsSatisfiedHcom(const std::map<uint32_t, vector<size_t>> &hcom_index, const CNodePtr &node_ptr, size_t index);
|
bool IsSatisfiedHcom(const std::map<uint32_t, vector<size_t>> &hcom_index, const CNodePtr &node_ptr, size_t index);
|
||||||
|
|
||||||
void GetProcessedStream(const NotNull<KernelGraphPtr> &graph_ptr);
|
void GetProcessedStream(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||||
|
@ -155,6 +164,7 @@ class AscendStreamAssign {
|
||||||
vector<CNodePtr>::iterator FindTargetOp(vector<CNodePtr>::iterator begin, vector<CNodePtr>::iterator end,
|
vector<CNodePtr>::iterator FindTargetOp(vector<CNodePtr>::iterator begin, vector<CNodePtr>::iterator end,
|
||||||
const CNodePtr &node);
|
const CNodePtr &node);
|
||||||
void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t> *parallel_streams);
|
void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t> *parallel_streams);
|
||||||
|
void SetLoopSink();
|
||||||
|
|
||||||
// function for memory resue
|
// function for memory resue
|
||||||
void GetStreamRelations();
|
void GetStreamRelations();
|
||||||
|
@ -172,17 +182,23 @@ class AscendStreamAssign {
|
||||||
|
|
||||||
bool independent_stream_activated_{false};
|
bool independent_stream_activated_{false};
|
||||||
bool hcom_stream_activated_{false};
|
bool hcom_stream_activated_{false};
|
||||||
|
bool loop_sink_{false};
|
||||||
|
// key:stream id, value:task nums;
|
||||||
std::map<uint32_t, uint32_t> independent_stream_map_{};
|
std::map<uint32_t, uint32_t> independent_stream_map_{};
|
||||||
std::map<uint32_t, uint32_t> hcom_stream_map_{};
|
std::map<uint32_t, uint32_t> hcom_stream_map_{};
|
||||||
std::map<uint32_t, uint32_t> common_stream_map_{};
|
std::map<uint32_t, uint32_t> common_stream_map_{};
|
||||||
std::set<uint32_t> processed_streams_{};
|
std::set<uint32_t> processed_streams_{};
|
||||||
std::vector<uint32_t> need_first_active_streams_{};
|
std::vector<uint32_t> need_first_active_streams_{};
|
||||||
std::set<CNodeKey> independent_targets_;
|
std::set<CNodeKey> independent_targets_;
|
||||||
|
// key:graph id, value:stream set
|
||||||
|
std::map<uint32_t, std::set<uint32_t>> hcom_graph_map_;
|
||||||
|
std::map<uint32_t, std::set<uint32_t>> independent_graph_map_;
|
||||||
|
|
||||||
// attr for memory copy reuse
|
// attr for memory copy reuse
|
||||||
std::map<uint32_t, std::vector<uint32_t>> stream_relations_{};
|
std::map<uint32_t, std::vector<uint32_t>> stream_relations_{};
|
||||||
std::vector<std::vector<uint32_t>> stream_groups_{};
|
std::vector<std::vector<uint32_t>> stream_groups_{};
|
||||||
std::map<CNodePtr, CNodePtr> event_map_;
|
std::map<CNodePtr, CNodePtr> event_map_{};
|
||||||
|
std::set<uint32_t> middle_active_streams_{};
|
||||||
// new policy end
|
// new policy end
|
||||||
};
|
};
|
||||||
} // namespace ascend
|
} // namespace ascend
|
||||||
|
|
Loading…
Reference in New Issue