!31173 Stream Assign Refactor

Merge pull request !31173 from jiaorui/refactor-stream
This commit is contained in:
i-robot 2022-03-17 11:39:11 +00:00 committed by Gitee
commit d37ee6e729
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 228 additions and 405 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -51,6 +51,7 @@ constexpr uint32_t kTaskNumPerHcomNode = 300;
constexpr uint32_t kTaskNumPerWorldHcomNode = 350;
constexpr uint32_t kTaskNumPerSameServerHcomNode = 125;
constexpr uint32_t kTaskNumPerHcomSendRecvNode = 15;
constexpr uint32_t kTaskNumPerCommonNode = 3;
constexpr size_t kHcomNum = 2;
constexpr size_t kLastGradHcomOffset = 2;
@ -212,20 +213,6 @@ void AscendStreamAssign::GetMaxStreamTaskNum() {
MS_LOG(INFO) << "AscendStreamAssign::max_task_count_: " << max_task_count_;
}
uint32_t AscendStreamAssign::max_stream_count() {
if (!max_stream_count_) {
GetMaxStreamTaskNum();
}
return max_stream_count_;
}
uint32_t AscendStreamAssign::max_task_count() {
if (!max_task_count_) {
GetMaxStreamTaskNum();
}
return max_task_count_;
}
void AscendStreamAssign::AssignStreamForNonTaskSink(const std::vector<CNodePtr> &kernels) {
if (kernels.empty()) {
return;
@ -265,7 +252,13 @@ void AscendStreamAssign::AssignStreamForNonTaskSink(const std::vector<CNodePtr>
}
void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) {
MS_LOG(INFO) << "Status record: start assign stream. graph id: " << graph_ptr->graph_id();
if (graph_ptr->is_dynamic_shape()) {
MS_LOG(WARNING) << "Dynamic shape do not need to assign stream.";
return;
}
MS_LOG(INFO) << "Status record: start assign stream. graph id: " << graph_ptr->graph_id()
<< ", sink node: " << IsTaskSink();
PROF_START(assign_stream);
if (!IsTaskSink()) {
auto kernels = graph_ptr->execution_order();
@ -276,55 +269,48 @@ void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr)
MS_LOG(INFO) << "Status record: end assign stream. graph id: " << graph_ptr->graph_id();
return;
}
if (!graph_ptr->is_dynamic_shape()) {
MS_LOG(INFO) << "Communication parallel mode: " << parallel::ParallelContext::GetInstance()->communi_parallel_mode()
<< ".";
MS_LOG(INFO) << "Communication parallel mode: " << parallel::ParallelContext::GetInstance()->communi_parallel_mode()
<< ".";
Reset();
SetLoopSink();
ReorderIndependentOrders(graph_ptr);
TrailingTimeOptimizationByReorder(graph_ptr);
Reset();
SetLoopSink();
GetMaxStreamTaskNum();
ReorderIndependentOrders(graph_ptr);
TrailingTimeOptimizationByReorder(graph_ptr);
AssignAllNodesStream(graph_ptr);
UpdateAtomicAddrCleanStreamId(graph_ptr);
InsertStreamActive(graph_ptr);
InsertEventForHcomParallel(graph_ptr);
InsertEventForIndependentParallel(graph_ptr);
InsertEventForMicroBatchIndependent(graph_ptr);
GetIndependentMaxTarget(graph_ptr);
InsertCtrlForIndependentParallel(graph_ptr);
AdjustAtomicAddrCleanOrder(graph_ptr);
AssignAllNodesStream(graph_ptr);
UpdateAtomicAddrCleanStreamId(graph_ptr);
InsertStreamActive(graph_ptr);
InsertEventForHcomParallel(graph_ptr);
InsertEventForIndependentParallel(graph_ptr);
InsertEventForMicroBatchIndependent(graph_ptr);
GetIndependentMaxTarget(graph_ptr);
InsertCtrlForIndependentParallel(graph_ptr);
AdjustAtomicAddrCleanOrder(graph_ptr);
GetNeedActiveStreams(graph_ptr);
GetNeedActiveStreams(graph_ptr);
MS_LOG(INFO) << "After finish stream assign and before check resource assign:";
graph_ptr->PrintGraphExecuteOrder();
CheckResourceAssign(graph_ptr);
MS_LOG(INFO) << "After finish stream assign and before check resource assign:";
graph_ptr->PrintGraphExecuteOrder();
CheckResourceAssign(graph_ptr);
#ifdef ENABLE_DUMP_IR
SubModuleId module = SubModuleId::SM_SESSION;
std::string name = "assign_stream." + std::to_string(graph_ptr->graph_id());
const std::vector<CNodePtr> &exec_order = graph_ptr->execution_order();
(void)mindspore::RDR::RecordStreamExecOrder(module, name, exec_order);
SubModuleId module = SubModuleId::SM_SESSION;
std::string name = "assign_stream." + std::to_string(graph_ptr->graph_id());
const std::vector<CNodePtr> &exec_order = graph_ptr->execution_order();
(void)mindspore::RDR::RecordStreamExecOrder(module, name, exec_order);
#endif
SetNodeStreamIDAttr(graph_ptr);
FindStreamRelations(graph_ptr);
PrintStreamRelations();
GetStreamRelations();
PrintStreamGroups();
FindEventRelations(graph_ptr);
}
SetNodeStreamIDAttr(graph_ptr);
FindStreamRelations(graph_ptr);
PrintStreamRelations();
GetStreamRelations();
PrintStreamGroups();
FindEventRelations(graph_ptr);
PROF_END(assign_stream);
MS_LOG(INFO) << "Status record: end assign stream. graph id: " << graph_ptr->graph_id();
}
void AscendStreamAssign::SetLoopSink() {
if (KernelAdjust::NeedLoopSink()) {
loop_sink_ = true;
} else {
loop_sink_ = false;
}
}
void AscendStreamAssign::SetLoopSink() { loop_sink_ = KernelAdjust::NeedLoopSink(); }
// section 1
void AscendStreamAssign::ReorderIndependentOrders(const NotNull<KernelGraphPtr> &graph_ptr) {
@ -654,222 +640,150 @@ void AscendStreamAssign::TrailingTimeOptimizationByReorder(const NotNull<KernelG
// section 2
void AscendStreamAssign::AssignAllNodesStream(const NotNull<KernelGraphPtr> &graph_ptr) {
auto cnode_ptr_list = graph_ptr->execution_order();
bool exit_independent = false;
bool exit_hcom = false;
std::vector<CNodePtr> common_node_list;
std::vector<CNodePtr> hcom_node_list;
std::vector<CNodePtr> independent_node_list;
AscendStreamMng &resource_manager = AscendStreamMng::GetInstance();
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
// node has been assigned stream before
if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) {
continue;
}
if (IsHcom(cur_cnode_ptr)) {
exit_hcom = true;
continue;
}
if (AnfAlgo::IsIndependentNode(cur_cnode_ptr)) {
exit_independent = true;
continue;
}
AssignCommonStreamId(cur_cnode_ptr);
}
ClassifyNodeByKernel(graph_ptr, &common_node_list, &hcom_node_list, &independent_node_list);
// Assign Stream for common node
common_stream_ = AssignNodeStreamInOrder(common_node_list);
// Common stream assignment of GetNext-While and EOS is executed in kernel-adjust, so the common_stream_num is
// acquired from resource manager rather than common_stream_.
auto common_stream_num = resource_manager.cur_stream_num();
if (exit_hcom) {
AssignHcom(graph_ptr);
// Assign Stream for hcom node
std::map<std::string, std::map<uint32_t, std::vector<CNodePtr>>> group_graph_nodes_map;
ClassifyNodeByGroupAndGraph(hcom_node_list, &group_graph_nodes_map);
for (const auto &iter_group : group_graph_nodes_map) {
for (const auto &iter_graph : iter_group.second) {
auto stream_set = AssignNodeStreamInOrder(iter_graph.second);
hcom_stream_.insert(stream_set.begin(), stream_set.end());
group_hcom_graph_map_[iter_group.first][iter_graph.first] = stream_set;
}
}
auto hcom_stream_num = resource_manager.cur_stream_num() - common_stream_num;
if (exit_independent) {
AssignIndependent(graph_ptr);
// Assign Stream for independent node
std::map<uint32_t, std::vector<CNodePtr>> graph_nodes_map;
ClassifyNodeByGraph(independent_node_list, &graph_nodes_map);
for (auto iter_graph : graph_nodes_map) {
independent_stream_ = AssignNodeStreamInOrder(iter_graph.second);
independent_graph_map_[iter_graph.first] = independent_stream_;
}
auto independent_stream_num = resource_manager.cur_stream_num() - common_stream_num - hcom_stream_num;
auto total_stream_num =
resource_manager.cur_stream_num() + Uint32tMulWithOverflowCheck(hcom_stream_num, kHcomSecondaryStreamNum);
resource_manager.cur_stream_num() + Uint32tMulWithOverflowCheck(hcom_stream_.size(), kHcomSecondaryStreamNum);
MS_LOG(INFO) << "Total stream number: " << total_stream_num << ", common stream number: " << common_stream_num
<< ", hcom stream number: " << hcom_stream_num << "*" << (kHcomSecondaryStreamNum + 1)
<< ", independent stream number: " << independent_stream_num << ".";
<< ", hcom stream number: " << hcom_stream_.size() << "*" << (kHcomSecondaryStreamNum + 1)
<< ", independent stream number: " << independent_stream_.size() << ".";
if (total_stream_num > max_stream_count()) {
MS_LOG(EXCEPTION) << "Total stream number " << total_stream_num << " exceeds the limit of " << max_stream_count()
if (total_stream_num > max_stream_count_) {
MS_LOG(EXCEPTION) << "Total stream number " << total_stream_num << " exceeds the limit of " << max_stream_count_
<< ", search details information in mindspore's FAQ.";
}
MS_LOG(INFO) << "After stream assign, total stream nums:" << resource_manager.cur_stream_num();
}
void AscendStreamAssign::AssignCommonStreamId(const CNodePtr &cur_cnode_ptr) {
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
AscendStreamMng &resource_manager = AscendStreamMng::GetInstance();
uint32_t cur_common_stream_id = 0;
uint32_t cur_stream_num = resource_manager.cur_stream_num();
if (cur_stream_num == 0) {
cur_common_stream_id = resource_manager.ApplyNewStream();
} else {
cur_common_stream_id = resource_manager.GetCurAllocStreamId();
}
auto it = common_stream_map_.find(cur_common_stream_id);
if (it == common_stream_map_.end()) {
AnfAlgo::SetStreamId(cur_common_stream_id, cur_cnode_ptr.get());
common_stream_map_.insert(std::make_pair(cur_common_stream_id, 1));
} else {
if (it->second < kMaxCommonNodeNumPerStream) {
AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get());
it->second++;
} else {
cur_common_stream_id = resource_manager.ApplyNewStream();
AnfAlgo::SetStreamId(cur_common_stream_id, cur_cnode_ptr.get());
common_stream_map_.insert(std::make_pair(cur_common_stream_id, 1));
}
}
}
void AscendStreamAssign::AssignHcom(const NotNull<KernelGraphPtr> &graph_ptr) {
auto cnode_ptr_list = graph_ptr->execution_order();
std::map<std::string, std::map<uint32_t, std::vector<CNodePtr>>> group_graph_nodes_map;
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
// node has been assigned stream before
if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) {
void AscendStreamAssign::ClassifyNodeByKernel(const NotNull<KernelGraphPtr> &graph_ptr,
std::vector<CNodePtr> *common_list, std::vector<CNodePtr> *hcom_list,
std::vector<CNodePtr> *independent_list) {
MS_EXCEPTION_IF_NULL(common_list);
MS_EXCEPTION_IF_NULL(hcom_list);
MS_EXCEPTION_IF_NULL(independent_list);
for (auto cur_cnode : graph_ptr->execution_order()) {
MS_EXCEPTION_IF_NULL(cur_cnode);
if (IsHcom(cur_cnode)) {
hcom_list->push_back(cur_cnode);
continue;
}
if (IsHcom(cur_cnode_ptr)) {
auto group_name = GetHcomGroup(cur_cnode_ptr);
auto hcom_graph_id = AnfAlgo::GetGraphId(cur_cnode_ptr.get());
auto iter = group_graph_nodes_map.find(group_name);
if (iter == group_graph_nodes_map.end()) {
std::map<uint32_t, std::vector<CNodePtr>> graph_nodes_map;
graph_nodes_map[hcom_graph_id] = {cur_cnode_ptr};
group_graph_nodes_map[group_name] = graph_nodes_map;
} else {
auto &graph_nodes_map = iter->second;
auto it = graph_nodes_map.find(hcom_graph_id);
if (it == graph_nodes_map.end()) {
graph_nodes_map[hcom_graph_id] = {cur_cnode_ptr};
} else {
it->second.emplace_back(cur_cnode_ptr);
}
}
if (AnfAlgo::IsIndependentNode(cur_cnode)) {
independent_list->push_back(cur_cnode);
continue;
}
}
MS_LOG(INFO) << "hcom diff group size:" << group_graph_nodes_map.size();
for (const auto &item : group_graph_nodes_map) {
MS_LOG_INFO << "group id:" << item.first << "; diff graph id size:" << item.second.size();
}
for (const auto &diff_group : group_graph_nodes_map) {
// group id:
std::map<uint32_t, std::set<uint32_t>> hcom_graph_map;
for (const auto &item : diff_group.second) {
bool new_graph = true;
auto graph_id = item.first;
hcom_graph_map[graph_id] = {};
for (const auto &hcom_node_ptr : item.second) {
auto assigned_stream_id = AssignHcomStreamId(hcom_node_ptr, new_graph);
hcom_graph_map[graph_id].emplace(assigned_stream_id);
new_graph = false;
}
}
group_hcom_graph_map_[diff_group.first] = hcom_graph_map;
common_list->push_back(cur_cnode);
}
}
uint32_t AscendStreamAssign::AssignHcomStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph) {
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
AscendStreamMng &resource_manager = AscendStreamMng::GetInstance();
auto task_num = GetHcomTaskNum(cur_cnode_ptr);
uint32_t cur_hcom_stream_id;
if (new_graph) {
cur_hcom_stream_id = resource_manager.ApplyNewStream();
} else {
cur_hcom_stream_id = resource_manager.GetCurAllocStreamId();
}
auto it = hcom_stream_map_.find(cur_hcom_stream_id);
if (it == hcom_stream_map_.end()) {
AnfAlgo::SetStreamId(cur_hcom_stream_id, cur_cnode_ptr.get());
hcom_stream_map_.emplace(cur_hcom_stream_id, task_num);
} else {
if (it->second <= max_task_count() - task_num) {
AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get());
it->second = Uint32tAddWithOverflowCheck(it->second, task_num);
} else {
cur_hcom_stream_id = resource_manager.ApplyNewStream();
AnfAlgo::SetStreamId(cur_hcom_stream_id, cur_cnode_ptr.get());
hcom_stream_map_.emplace(cur_hcom_stream_id, task_num);
}
}
return cur_hcom_stream_id;
}
void AscendStreamAssign::AssignIndependent(const NotNull<KernelGraphPtr> &graph_ptr) {
auto cnode_ptr_list = graph_ptr->execution_order();
std::map<uint32_t, std::vector<CNodePtr>> graph_nodes_map;
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
void AscendStreamAssign::ClassifyNodeByGroupAndGraph(const std::vector<CNodePtr> hcom_list,
GroupGraphMap *group_graph_map) {
MS_EXCEPTION_IF_NULL(group_graph_map);
for (auto cur_cnode_ptr : hcom_list) {
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) {
continue;
if (!IsHcom(cur_cnode_ptr)) {
MS_LOG(EXCEPTION) << "Node is not hcom node, it's " << cur_cnode_ptr->fullname_with_scope();
}
if (AnfAlgo::IsIndependentNode(cur_cnode_ptr)) {
auto independent_graph_id = AnfAlgo::GetGraphId(cur_cnode_ptr.get());
auto it = graph_nodes_map.find(independent_graph_id);
auto group_name = GetHcomGroup(cur_cnode_ptr);
auto hcom_graph_id = AnfAlgo::GetGraphId(cur_cnode_ptr.get());
auto iter = group_graph_map->find(group_name);
if (iter == group_graph_map->end()) {
std::map<uint32_t, std::vector<CNodePtr>> graph_nodes_map;
graph_nodes_map[hcom_graph_id] = {cur_cnode_ptr};
(*group_graph_map)[group_name] = graph_nodes_map;
} else {
auto &graph_nodes_map = iter->second;
auto it = graph_nodes_map.find(hcom_graph_id);
if (it == graph_nodes_map.end()) {
graph_nodes_map[independent_graph_id] = {cur_cnode_ptr};
graph_nodes_map[hcom_graph_id] = {cur_cnode_ptr};
} else {
it->second.emplace_back(cur_cnode_ptr);
}
}
}
MS_LOG(INFO) << "independent diff graph id size:" << graph_nodes_map.size();
for (const auto &item : graph_nodes_map) {
bool new_graph = true;
auto graph_id = item.first;
independent_graph_map_[graph_id] = {};
for (const auto &independent_node_ptr : item.second) {
auto assigned_stream_id = AssignIndependentStreamId(independent_node_ptr, new_graph);
independent_graph_map_[graph_id].emplace(assigned_stream_id);
new_graph = false;
}
}
MS_LOG(INFO) << "stream nums:" << independent_stream_map_.size();
}
uint32_t AscendStreamAssign::AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph) {
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
std::set<uint32_t> AscendStreamAssign::AssignNodeStreamInOrder(const std::vector<CNodePtr> node_list) {
AscendStreamMng &resource_manager = AscendStreamMng::GetInstance();
uint32_t cur_independent_stream_id;
if (new_graph) {
cur_independent_stream_id = resource_manager.ApplyNewStream();
} else {
cur_independent_stream_id = resource_manager.GetCurAllocStreamId();
}
auto it = independent_stream_map_.find(cur_independent_stream_id);
if (it == independent_stream_map_.end()) {
AnfAlgo::SetStreamId(cur_independent_stream_id, cur_cnode_ptr.get());
independent_stream_map_.insert(std::make_pair(cur_independent_stream_id, 1));
} else {
if (it->second < kMaxCommonNodeNumPerStream) {
AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get());
it->second++;
auto cur_stream_id = resource_manager.ApplyNewStream();
std::map<uint32_t, uint32_t> stream_task_map;
std::set<uint32_t> stream_set;
for (auto cur_cnode_ptr : node_list) {
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) {
continue;
}
auto task_num = GetNodeTaskNum(cur_cnode_ptr);
auto it = stream_task_map.find(cur_stream_id);
if (it == stream_task_map.end()) {
AnfAlgo::SetStreamId(cur_stream_id, cur_cnode_ptr.get());
stream_task_map.emplace(cur_stream_id, task_num);
stream_set.emplace(cur_stream_id);
} else {
cur_independent_stream_id = resource_manager.ApplyNewStream();
AnfAlgo::SetStreamId(cur_independent_stream_id, cur_cnode_ptr.get());
independent_stream_map_.insert(std::make_pair(cur_independent_stream_id, 1));
if (it->second <= max_task_count_ - task_num) {
AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get());
it->second = Uint32tAddWithOverflowCheck(it->second, task_num);
} else {
cur_stream_id = resource_manager.ApplyNewStream();
AnfAlgo::SetStreamId(cur_stream_id, cur_cnode_ptr.get());
stream_task_map.emplace(cur_stream_id, task_num);
stream_set.emplace(cur_stream_id);
}
}
}
if (stream_set.empty()) {
resource_manager.DeleteStream();
}
return stream_set;
}
return cur_independent_stream_id;
void AscendStreamAssign::ClassifyNodeByGraph(const std::vector<CNodePtr> indepent_list,
std::map<uint32_t, std::vector<CNodePtr>> *graph_nodes_map) {
MS_EXCEPTION_IF_NULL(graph_nodes_map);
for (auto cur_cnode_ptr : indepent_list) {
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
if (!AnfAlgo::IsIndependentNode(cur_cnode_ptr)) {
MS_LOG(EXCEPTION) << "Node is not independent node, it's " << cur_cnode_ptr->fullname_with_scope();
}
auto independent_graph_id = AnfAlgo::GetGraphId(cur_cnode_ptr.get());
auto it = graph_nodes_map->find(independent_graph_id);
if (it == graph_nodes_map->end()) {
(*graph_nodes_map)[independent_graph_id] = {cur_cnode_ptr};
} else {
it->second.emplace_back(cur_cnode_ptr);
}
}
}
uint32_t AscendStreamAssign::GetNodeTaskNum(const CNodePtr &cnode) {
return IsHcom(cnode) ? GetHcomTaskNum(cnode) : kTaskNumPerCommonNode;
}
// section 3
@ -1100,7 +1014,8 @@ void AscendStreamAssign::InsertStreamActiveForCommon(const NotNull<KernelGraphPt
continue;
}
uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
bool processed = IsProcessedStream(cur_stream_id);
bool processed = std::any_of(processed_streams_.begin(), processed_streams_.end(),
[cur_stream_id](uint32_t iter_stream) { return iter_stream == cur_stream_id; });
// 1)inner stream assign, need insert active op
if (!processed) {
MS_LOG(INFO) << "Common stream active info:" << pre_stream_id << "->active" << cur_stream_id;
@ -1224,36 +1139,6 @@ void AscendStreamAssign::GetProcessedStream(const NotNull<KernelGraphPtr> &graph
}
}
bool AscendStreamAssign::CheckStreamSwitch(const CNodePtr &switch_ptr) {
if (!common::AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, switch_ptr)) {
return false;
}
auto need_active = common::AnfAlgo::GetNodeAttr<bool>(switch_ptr, kStreamNeedActivedFirst);
if (!need_active) {
return false;
}
if (!common::AnfAlgo::HasNodeAttr(kAttrStreamSwitchKind, switch_ptr)) {
return false;
}
auto kind = common::AnfAlgo::GetNodeAttr<uint32_t>(switch_ptr, kAttrStreamSwitchKind);
if (kind == kEosStreamSwitch || kind == kGetNextStreamSwitch) {
return false;
}
return true;
}
bool AscendStreamAssign::IsProcessedStream(uint32_t stream_id) {
auto it = std::find(processed_streams_.begin(), processed_streams_.end(), stream_id);
if (it != processed_streams_.end()) {
return true;
}
return false;
}
bool AscendStreamAssign::IsAllOutGraphOut(const KernelGraphPtr &graph, const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(cnode);
@ -1277,18 +1162,6 @@ bool AscendStreamAssign::IsAllOutGraphOut(const KernelGraphPtr &graph, const CNo
return cnode_out_num == output_index_set.size();
}
vector<CNodePtr>::iterator AscendStreamAssign::FindGraphEnd(vector<CNodePtr>::iterator begin,
vector<CNodePtr>::iterator end) {
while (begin != end) {
if (common::AnfAlgo::HasNodeAttr(kAttrFpBpEnd, *begin)) {
MS_LOG(INFO) << "FpBp end op is " << (*begin)->fullname_with_scope();
return begin;
}
++begin;
}
return end;
}
// section5
void AscendStreamAssign::InsertEventForHcomParallel(const NotNull<KernelGraphPtr> &graph_ptr) {
MS_LOG(INFO) << "Start";
@ -1401,7 +1274,8 @@ void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull<KernelGraphPt
if (target == cnodes.end()) {
if (IsAllOutGraphOut(graph_ptr, cur_hcom_node)) {
// if hcom's all output is graph output, we need to insert send/recv to fpbp end in data sink mode
target = FindGraphEnd(it, cnodes.end());
target = std::find_if(
it, cnodes.end(), [](CNodePtr temp_node) { return common::AnfAlgo::HasNodeAttr(kAttrFpBpEnd, temp_node); });
}
if (target == cnodes.end()) {
@ -1811,7 +1685,6 @@ void AscendStreamAssign::InsertEventForIndependentParallel(const NotNull<KernelG
graph_ptr->set_execution_order(new_cnodes);
MS_LOG(INFO) << "After independent parallel, total event nums:" << resource_manager.cur_event_num();
MS_LOG(INFO) << "End";
}
void AscendStreamAssign::GetIndependentMaxTarget(const NotNull<KernelGraphPtr> &graph_ptr) {
@ -1863,14 +1736,9 @@ void AscendStreamAssign::GetIndependentMaxTarget(const NotNull<KernelGraphPtr> &
uint32_t AscendStreamAssign::GetIndexByKey(const NotNull<KernelGraphPtr> &graph_ptr, const CNodeKey &key) {
auto &exe_orders = graph_ptr->execution_order();
for (uint32_t i = 0; i < exe_orders.size(); i++) {
CNodeKey node_key = exe_orders[i].get();
if (node_key == key) {
return i;
}
}
return UINT32_MAX;
auto result =
std::find_if(exe_orders.begin(), exe_orders.end(), [key](CNodePtr cnode) { return cnode.get() == key; });
return result == exe_orders.end() ? UINT32_MAX : (result - exe_orders.begin());
}
uint32_t AscendStreamAssign::GetMaxIndexTarget(const NotNull<KernelGraphPtr> &graph_ptr) {
@ -2022,7 +1890,8 @@ void AscendStreamAssign::CheckStreamAssign(const NotNull<KernelGraphPtr> &graph_
// check stream assign
if (!streams.empty()) {
if (min_stream != 0) {
MS_LOG(EXCEPTION) << "Stream should start from 0, now is from " << min_stream;
MS_LOG(EXCEPTION) << "Stream should start from 0, now is from " << min_stream
<< ", graph id: " << graph_ptr->graph_id();
}
uint32_t assigned_stream_num = resource_manager.cur_stream_num();
if ((max_stream != assigned_stream_num - 1) || (streams.size() != assigned_stream_num)) {
@ -2086,41 +1955,15 @@ void AscendStreamAssign::CheckEventAssign(const NotNull<KernelGraphPtr> &graph_p
// section9
CNodePtr AscendStreamAssign::CreateSendApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id,
uint32_t stream_id) {
auto send_op = std::make_shared<Primitive>(kSendOpName);
MS_EXCEPTION_IF_NULL(send_op);
auto send_apply = std::make_shared<ValueNode>(send_op);
MS_EXCEPTION_IF_NULL(send_apply);
std::vector<AnfNodePtr> send_input_list = {send_apply};
CNodePtr send_node_ptr = graph_ptr->NewCNode(send_input_list);
MS_EXCEPTION_IF_NULL(send_node_ptr);
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL);
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), send_node_ptr.get());
common::AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), send_node_ptr);
auto abstract_none = std::make_shared<abstract::AbstractNone>();
MS_EXCEPTION_IF_NULL(abstract_none);
send_node_ptr->set_abstract(abstract_none);
auto send_node_ptr = KernelAdjust::GetInstance().CreateSendApplyKernel(graph_ptr, event_id);
AnfAlgo::SetStreamId(stream_id, send_node_ptr.get());
return send_node_ptr;
}
CNodePtr AscendStreamAssign::CreateRecvApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id,
uint32_t stream_id) {
auto recv_op = std::make_shared<Primitive>(kRecvOpName);
MS_EXCEPTION_IF_NULL(recv_op);
auto recv_apply = std::make_shared<ValueNode>(recv_op);
MS_EXCEPTION_IF_NULL(recv_apply);
std::vector<AnfNodePtr> recv_input_list = {recv_apply};
CNodePtr recv_node_ptr = graph_ptr->NewCNode(recv_input_list);
MS_EXCEPTION_IF_NULL(recv_node_ptr);
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL);
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), recv_node_ptr.get());
common::AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), recv_node_ptr);
auto recv_node_ptr = KernelAdjust::GetInstance().CreateRecvApplyKernel(graph_ptr, event_id);
AnfAlgo::SetStreamId(stream_id, recv_node_ptr.get());
auto abstract_none = std::make_shared<abstract::AbstractNone>();
MS_EXCEPTION_IF_NULL(abstract_none);
recv_node_ptr->set_abstract(abstract_none);
return recv_node_ptr;
}
@ -2176,13 +2019,7 @@ vector<CNodePtr>::iterator AscendStreamAssign::FindTargetOp(vector<CNodePtr>::it
bool AscendStreamAssign::IsTaskSink() {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (!ms_context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
MS_LOG(INFO) << "Task sink mode is not enable";
return false;
} else {
MS_LOG(INFO) << "Task sink mode is enable";
return true;
}
return ms_context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
}
void AscendStreamAssign::GetWaitStreams(vector<uint32_t> *wait_active_stream_list) {
@ -2204,24 +2041,23 @@ void AscendStreamAssign::GetWaitStreams(vector<uint32_t> *wait_active_stream_lis
}
}
void AscendStreamAssign::GetHcomStreams(std::vector<uint32_t> *streams) {
MS_EXCEPTION_IF_NULL(streams);
std::transform(hcom_stream_.begin(), hcom_stream_.end(), std::back_inserter(*streams),
[](const uint32_t &item) { return item; });
}
bool AscendStreamAssign::IsHcom(const CNodePtr &apply_kernel) {
MS_EXCEPTION_IF_NULL(apply_kernel);
return AnfAlgo::GetKernelType(apply_kernel) == HCCL_KERNEL;
}
void AscendStreamAssign::GetHcomStreams(std::vector<uint32_t> *streams) {
MS_EXCEPTION_IF_NULL(streams);
std::transform(hcom_stream_map_.begin(), hcom_stream_map_.end(), std::back_inserter(*streams),
[](const std::pair<uint32_t, uint32_t> &item) { return item.first; });
}
void AscendStreamAssign::Reset() {
independent_stream_activated_ = false;
hcom_stream_activated_ = false;
loop_sink_ = false;
independent_stream_map_.clear();
hcom_stream_map_.clear();
common_stream_map_.clear();
independent_stream_.clear();
hcom_stream_.clear();
processed_streams_.clear();
need_first_active_streams_.clear();
stream_groups_.clear();
@ -2354,6 +2190,7 @@ void AscendStreamAssign::GetStreamActiveStreamRelation(const NotNull<KernelGraph
}
auto cur_cnode = orders[index];
auto cur_stream_id = AnfAlgo::GetStreamId(cur_cnode);
auto active_list = common::AnfAlgo::GetNodeAttr<vector<uint32_t>>(cur_cnode, kAttrActiveStreamList);
if (kind == kHead) {
uint32_t active_current_stream_id = GetStreamByActivedStream(cur_stream_id);
@ -2434,13 +2271,13 @@ StreamActiveKind AscendStreamAssign::GetStreamActiveKind(const NotNull<KernelGra
continue;
}
auto stream = AnfAlgo::GetStreamId(cnode);
auto it = hcom_stream_map_.find(stream);
if (it != hcom_stream_map_.end()) {
auto it = hcom_stream_.find(stream);
if (it != hcom_stream_.end()) {
continue;
}
it = independent_stream_map_.find(stream);
if (it != independent_stream_map_.end()) {
it = independent_stream_.find(stream);
if (it != independent_stream_.end()) {
continue;
}
@ -2455,13 +2292,13 @@ StreamActiveKind AscendStreamAssign::GetStreamActiveKind(const NotNull<KernelGra
}
auto stream = AnfAlgo::GetStreamId(cnode);
auto it = hcom_stream_map_.find(stream);
if (it != hcom_stream_map_.end()) {
auto it = hcom_stream_.find(stream);
if (it != hcom_stream_.end()) {
continue;
}
it = independent_stream_map_.find(stream);
if (it != independent_stream_map_.end()) {
it = independent_stream_.find(stream);
if (it != independent_stream_.end()) {
continue;
}

View File

@ -43,6 +43,7 @@ using std::unordered_map;
using std::unordered_set;
using std::vector;
using CNodeKey = void *;
using GroupGraphMap = std::map<std::string, std::map<uint32_t, std::vector<CNodePtr>>>;
const uint32_t kInvalidStreamId = UINT32_MAX;
const uint32_t kInvalidEventId = UINT32_MAX;
enum StreamActiveKind { kInvalid = 0, kHead, kMiddle, kTail };
@ -57,32 +58,32 @@ class AscendStreamAssign {
AscendStreamAssign &operator=(const AscendStreamAssign &) = delete;
void AssignStream(const NotNull<KernelGraphPtr> &graph_ptr);
void GetHcomStreams(std::vector<uint32_t> *streams);
void GetWaitStreams(vector<uint32_t> *wait_active_stream_list);
void GetHcomStreams(std::vector<uint32_t> *streams);
void AssignStreamForNonTaskSink(const std::vector<CNodePtr> &kernels);
const std::vector<std::vector<uint32_t>> &get_stream_group() const { return stream_groups_; }
const std::map<CNodePtr, CNodePtr> &get_event_map() const { return event_map_; }
uint32_t max_stream_count();
uint32_t max_task_count();
private:
AscendStreamAssign() = default;
~AscendStreamAssign() = default;
void GetMaxStreamTaskNum();
void Reset();
void AssignAllNodesStream(const NotNull<KernelGraphPtr> &graph_ptr);
std::set<uint32_t> AssignNodeStreamInOrder(const std::vector<CNodePtr> node_list);
void ClassifyNodeByKernel(const NotNull<KernelGraphPtr> &graph_ptr, std::vector<CNodePtr> *common_list,
std::vector<CNodePtr> *hcom_list, std::vector<CNodePtr> *independent_list);
void ClassifyNodeByGroupAndGraph(const std::vector<CNodePtr> hcom_list, GroupGraphMap *group_graph_map);
void ClassifyNodeByGraph(const std::vector<CNodePtr> indepent_list,
std::map<uint32_t, std::vector<CNodePtr>> *graph_node_map);
uint32_t GetNodeTaskNum(const CNodePtr &cnode);
CNodePtr CreateSendApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id, uint32_t stream_id);
CNodePtr CreateRecvApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id, uint32_t stream_id);
void CheckResourceAssign(const NotNull<KernelGraphPtr> &graph_ptr);
void CheckStreamAssign(const NotNull<KernelGraphPtr> &graph_ptr);
void CheckEventAssign(const NotNull<KernelGraphPtr> &graph_ptr);
void AssignAllNodesStream(const NotNull<KernelGraphPtr> &graph_ptr);
void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr);
void AssignHcom(const NotNull<KernelGraphPtr> &graph_ptr);
uint32_t AssignHcomStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph);
void AssignIndependent(const NotNull<KernelGraphPtr> &graph_ptr);
uint32_t AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph);
void UpdateAtomicAddrCleanStreamId(const NotNull<KernelGraphPtr> &graph_ptr);
void FindHcomParallelStreams(const NotNull<KernelGraphPtr> &graph_ptr);
void InsertStreamActive(const NotNull<KernelGraphPtr> &graph_ptr);
void InsertStreamActiveForCommon(const NotNull<KernelGraphPtr> &graph_ptr);
void InsertStreamActiveForIndependent(const NotNull<KernelGraphPtr> &graph_ptr);
@ -92,7 +93,6 @@ class AscendStreamAssign {
const std::set<uint32_t> &independent_streams);
void ActiveOtherGraphParallel(const NotNull<KernelGraphPtr> &graph_ptr,
std::map<uint32_t, std::set<uint32_t>> other_graph);
bool CheckStreamSwitch(const CNodePtr &switch_ptr);
void InsertEventForIndependentParallel(const NotNull<KernelGraphPtr> &graph_ptr);
void InsertCtrlForIndependentParallel(const NotNull<KernelGraphPtr> &graph_ptr);
void InsertEventForHcomParallel(const NotNull<KernelGraphPtr> &graph_ptr);
@ -134,11 +134,11 @@ class AscendStreamAssign {
bool IsTaskSink();
bool IsHcom(const CNodePtr &cur_cnode_ptr);
bool IsIndependentNode(const CNodePtr &node_ptr);
bool IsProcessedStream(uint32_t stream_id);
vector<CNodePtr>::iterator FindTargetOp(vector<CNodePtr>::iterator begin, vector<CNodePtr>::iterator end,
const CNodePtr &node, bool exclude_hcom);
void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t> *parallel_streams);
void SetLoopSink();
void GetMaxStreamTaskNum();
void Reset();
// function for memory reuse
void GetStreamRelations();
@ -165,12 +165,10 @@ class AscendStreamAssign {
bool hcom_stream_activated_{false};
bool loop_sink_{false};
// key:stream id, value:node number
std::map<uint32_t, uint32_t> common_stream_map_{};
// key:stream id, value:node number
std::map<uint32_t, uint32_t> independent_stream_map_{};
// key:stream id, value:task number
std::map<uint32_t, uint32_t> hcom_stream_map_{};
std::set<uint32_t> common_stream_{};
std::set<uint32_t> independent_stream_{};
std::set<uint32_t> hcom_stream_{};
std::set<uint32_t> processed_streams_{};
std::vector<uint32_t> need_first_active_streams_{};
@ -189,7 +187,6 @@ class AscendStreamAssign {
std::set<uint32_t> middle_active_streams_{};
// new policy end
bool IsAllOutGraphOut(const KernelGraphPtr &graph, const CNodePtr &cnode);
vector<CNodePtr>::iterator FindGraphEnd(vector<CNodePtr>::iterator begin, vector<CNodePtr>::iterator end);
uint32_t max_stream_count_ = 0;
uint32_t max_task_count_ = 0;

View File

@ -44,6 +44,14 @@ class AscendStreamMng {
}
}
void DeleteStream() {
if (!cur_stream_num_) {
MS_LOG(WARNING) << " total stream num is 0, no stream to delete";
} else {
--cur_stream_num_;
}
}
uint32_t cur_stream_num() const { return cur_stream_num_; }
uint32_t GetCurAllocStreamId() {

View File

@ -74,6 +74,20 @@ bool KernelAdjust::NeedLoopSink() {
context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK) && ConfigManager::GetInstance().iter_num() > 1);
}
CNodePtr CreateEventApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id,
std::vector<AnfNodePtr> input_list) {
MS_EXCEPTION_IF_NULL(graph_ptr);
CNodePtr event_node_ptr = graph_ptr->NewCNode(input_list);
MS_EXCEPTION_IF_NULL(event_node_ptr);
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL);
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), event_node_ptr.get());
common::AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), event_node_ptr);
auto abstract_none = std::make_shared<abstract::AbstractNone>();
event_node_ptr->set_abstract(abstract_none);
return event_node_ptr;
}
CNodePtr KernelAdjust::CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr,
uint32_t event_id) {
MS_EXCEPTION_IF_NULL(graph_ptr);
@ -81,17 +95,7 @@ CNodePtr KernelAdjust::CreateSendApplyKernel(const std::shared_ptr<session::Kern
MS_EXCEPTION_IF_NULL(send_op);
auto send_apply = std::make_shared<ValueNode>(send_op);
MS_EXCEPTION_IF_NULL(send_apply);
std::vector<AnfNodePtr> send_input_list = {send_apply};
CNodePtr send_node_ptr = graph_ptr->NewCNode(send_input_list);
MS_EXCEPTION_IF_NULL(send_node_ptr);
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL);
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), send_node_ptr.get());
common::AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), send_node_ptr);
auto abstract_none = std::make_shared<abstract::AbstractNone>();
MS_EXCEPTION_IF_NULL(abstract_none);
send_node_ptr->set_abstract(abstract_none);
return send_node_ptr;
return CreateEventApplyKernel(graph_ptr, event_id, {send_apply});
}
CNodePtr KernelAdjust::CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr,
@ -101,41 +105,22 @@ CNodePtr KernelAdjust::CreateRecvApplyKernel(const std::shared_ptr<session::Kern
MS_EXCEPTION_IF_NULL(recv_op);
auto recv_apply = std::make_shared<ValueNode>(recv_op);
MS_EXCEPTION_IF_NULL(recv_apply);
std::vector<AnfNodePtr> recv_input_list = {recv_apply};
CNodePtr recv_node_ptr = graph_ptr->NewCNode(recv_input_list);
MS_EXCEPTION_IF_NULL(recv_node_ptr);
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL);
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), recv_node_ptr.get());
common::AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), recv_node_ptr);
auto abstract_none = std::make_shared<abstract::AbstractNone>();
MS_EXCEPTION_IF_NULL(abstract_none);
recv_node_ptr->set_abstract(abstract_none);
return recv_node_ptr;
return CreateEventApplyKernel(graph_ptr, event_id, {recv_apply});
}
bool KernelAdjust::ExistGetNext(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
const std::vector<CNodePtr> &cnode_list = kernel_graph_ptr->execution_order();
for (const auto &cnode : cnode_list) {
if (common::AnfAlgo::GetCNodeName(cnode) == kGetNextOpName) {
return true;
}
}
return false;
return std::any_of(cnode_list.begin(), cnode_list.end(),
[](const CNodePtr &cnode) { return common::AnfAlgo::GetCNodeName(cnode) == kGetNextOpName; });
}
bool KernelAdjust::ExistIndependent(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
const auto &exe_orders = kernel_graph_ptr->execution_order();
for (const auto &node : exe_orders) {
if (AnfAlgo::IsIndependentNode(node) && AnfAlgo::GetGraphId(node.get()) == kernel_graph_ptr->graph_id()) {
MS_LOG(INFO) << "graph exit independent node";
return true;
}
}
return false;
return std::any_of(exe_orders.begin(), exe_orders.end(), [&kernel_graph_ptr](const CNodePtr &node) {
return AnfAlgo::IsIndependentNode(node) && AnfAlgo::GetGraphId(node.get()) == kernel_graph_ptr->graph_id();
});
}
void KernelAdjust::InsertIndepentParallel(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,

View File

@ -70,6 +70,8 @@ class KernelAdjust {
#endif
static bool NeedLoopSink();
CNodePtr CreateStreamActiveOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id);
CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id);
private:
KernelAdjust() = default;
@ -84,8 +86,6 @@ class KernelAdjust {
const AnfNodePtr &specify_para);
CNodePtr CreateAssign(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, const AnfNodePtr &specify_para);
void ReorderGetNext(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id);
CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id);
CNodePtr CreateStreamSwitchOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input,
StreamSwitchKind kind);

View File

@ -24,10 +24,6 @@ namespace ascend {
void AscendLabelAssign::AssignLabel(NotNull<std::shared_ptr<session::KernelGraph>> graph) {}
uint32_t AscendLabelAssign::GetLabelNum(NotNull<const session::KernelGraph *> graph) { return 1; }
uint32_t AscendLabelAssign::GetLabelNum(NotNull<std::shared_ptr<session::KernelGraph>> graph) { return 1; }
uint32_t AscendStreamAssign::max_stream_count() { return 1; }
uint32_t AscendStreamAssign::max_task_count() { return 1; }
void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) { return; }
void AscendStreamAssign::GetWaitStreams(vector<uint32_t> *wait_active_stream_list) { return; }