From f8208c7c522cf6e303b8aec2dff16a8a748ca90b Mon Sep 17 00:00:00 2001 From: gukecai Date: Thu, 16 Apr 2020 15:46:09 +0800 Subject: [PATCH] Support GetNext Parallel --- .../device/ascend/ascend_kernel_runtime.cc | 9 +- .../device/ascend/ascend_stream_assign.cc | 154 +++++++--- .../device/ascend/ascend_stream_assign.h | 33 +- mindspore/ccsrc/device/kernel_adjust.cc | 288 +++++++++--------- mindspore/ccsrc/device/kernel_adjust.h | 25 +- .../ascend/ascend_backend_optimization.cc | 7 + .../pre_activate/mem_reuse/stream_reuse.cc | 4 +- mindspore/ccsrc/session/ascend_session.cc | 2 +- mindspore/ccsrc/utils/utils.h | 3 + .../tasksink/ascend_stream_assign_stub.cc | 4 +- 10 files changed, 305 insertions(+), 224 deletions(-) diff --git a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc index 935e694636f..44cf3f8fa87 100644 --- a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc +++ b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc @@ -283,18 +283,19 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance(); // the streams' flag not HEAD_STREAM - std::vector wait_active_stream_list = assign_instance.GetWaitStreams(); - std::vector force_copy_stream_list = assign_instance.GetHcomStreams(); + std::vector wait_active_stream_list; + assign_instance.GetWaitStreams(&wait_active_stream_list); + auto force_copy_stream_list = assign_instance.hcom_streams(); MS_LOG(INFO) << "call DavinciModel total stream num:" << assign_instance.GetTotalStreamNum() - << ", total event num:" << assign_instance.GetTotalEventNum() + << ", total event num:" << assign_instance.total_event_num() << ", wait_active_stream_list size:" << wait_active_stream_list.size() << ", force_copy_stream_list size:" << force_copy_stream_list.size(); std::vector> empty_list; std::shared_ptr model = std::make_shared( task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0, - 0, 0, 0, 0, 0, assign_instance.GetTotalStreamNum(), 1, assign_instance.GetTotalEventNum(), 0); + 0, 0, 0, 0, 0, assign_instance.GetTotalStreamNum(), 1, assign_instance.total_event_num(), 0); auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model)); if (!ret.second) { diff --git a/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc b/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc index 8c4d1f4a8f6..e2cf469cd80 100644 --- a/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc +++ b/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc @@ -25,8 +25,8 @@ #include "session/anf_runtime_algorithm.h" #include "device/kernel_adjust.h" #include "predict/generator/utils/ir_model_util.h" -#include "device/kernel_info.h" #include "pre_activate/common/helper.h" +#include "utils/utils.h" namespace mindspore { namespace device { @@ -54,6 +54,7 @@ void AscendStreamAssign::ResetNew() { inner_parallel_streams_.clear(); processed_parallel_streams_.clear(); hcom_stream_list_.clear(); + need_first_active_streams_.clear(); } void AscendStreamAssign::AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, uint32_t processing_logic_id) { @@ -200,13 +201,12 @@ void AscendStreamAssign::AssignAllNodesStream(const shared_ptr AscendStreamAssign::TransLogicToPhysic(const vector &logic_ids) { - vector physic_ids; +void AscendStreamAssign::TransLogicToPhysic(const vector &logic_ids, vector *physic_ids) { for (auto &id : logic_ids) { auto it = logic_to_physic_map_.find(id); if (it != logic_to_physic_map_.end()) { MS_LOG(INFO) << "logic id[" << id << "] to physic id[" << it->second << "]"; - physic_ids.push_back(it->second); + (*physic_ids).push_back(it->second); } else { MS_LOG(EXCEPTION) << "logic id[" << id << "] has no correspond physic id"; } @@ -214,10 +214,9 @@ vector AscendStreamAssign::TransLogicToPhysic(const vector & auto it_independ = logic_to_independent_map_.find(id); if (it_independ != logic_to_independent_map_.end()) { MS_LOG(INFO) << "logic id[" << id << "] to independent id[" << it_independ->second << "]"; - physic_ids.push_back(it_independ->second); + (*physic_ids).push_back(it_independ->second); } } - return physic_ids; } void AscendStreamAssign::UpdateStreamActive(const CNodePtr &active_ptr) { @@ -227,7 +226,8 @@ void AscendStreamAssign::UpdateStreamActive(const CNodePtr &active_ptr) { MS_EXCEPTION_IF_NULL(primitive); vector active_logic_ids = GetValue>(primitive->GetAttr(kAttrActiveStreamList)); // out StreamAcitve active physic stream is not parallel now, if parallel, should deal here. - vector active_physic_ids = TransLogicToPhysic(active_logic_ids); + vector active_physic_ids; + TransLogicToPhysic(active_logic_ids, &active_physic_ids); ValuePtr active_physic_value = MakeValue>(active_physic_ids); AnfAlgo::SetNodeAttr(kAttrActiveStreamList, active_physic_value, active_ptr); } @@ -242,7 +242,8 @@ void AscendStreamAssign::UpdateStreamSwitch(const CNodePtr &switch_ptr, const CN MS_LOG(INFO) << "streamswtich stream id[" << AnfAlgo::GetStreamId(switch_ptr) << "], true_logic_id[" << true_logic_id << "]"; vector logic_ids{true_logic_id}; - vector physic_ids = TransLogicToPhysic(logic_ids); + vector physic_ids; + TransLogicToPhysic(logic_ids, &physic_ids); if (physic_ids.empty()) { MS_LOG(EXCEPTION) << "stream switch true logic id[" << true_logic_id << "] has no physical id"; } @@ -334,8 +335,8 @@ bool AscendStreamAssign::IsProcessedParallelStream(uint32_t stream_id) { return false; } -vector AscendStreamAssign::GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id) { - vector parallel_streams; +void AscendStreamAssign::GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, + vector *parallel_streams) { for (size_t i = 0; i < inner_parallel_streams_.size(); i++) { auto cur_parallel_streams = inner_parallel_streams_[i]; auto it = std::find(cur_parallel_streams.begin(), cur_parallel_streams.end(), cur_stream_id); @@ -347,17 +348,17 @@ vector AscendStreamAssign::GetParallelStream(uint32_t cur_stream_id, u << "is same with streamacvite stream id" << stream_acitve_id; continue; } - parallel_streams.emplace_back(cur_parallel_streams[j]); + (*parallel_streams).emplace_back(cur_parallel_streams[j]); } // record processed parallel streams - (void)std::copy(parallel_streams.begin(), parallel_streams.end(), + (void)std::copy((*parallel_streams).begin(), (*parallel_streams).end(), std::back_inserter(processed_parallel_streams_)); - return parallel_streams; + return; } } - return vector{cur_stream_id}; + (*parallel_streams).push_back(cur_stream_id); } void AscendStreamAssign::InsertActiveNew(const std::shared_ptr &graph_ptr) { @@ -379,30 +380,32 @@ void AscendStreamAssign::InsertActiveNew(const std::shared_ptr active_index_list = GetParallelStream(cur_stream_id, pre_stream_id); + std::vector active_index_list; + GetParallelStream(cur_stream_id, pre_stream_id, &active_index_list); AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(active_index_list), active_ptr); - } else if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == "StreamActive" && - AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) != UINT32_MAX) { + } + // inner_active is not a if/else relationship with the next if/else. such as:StreamActive(S7)-->StreamActive(S8) + if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamActiveOpName && + AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) != UINT32_MAX) { // 2)outter stream assign, update active op update_cnode_list.emplace_back(cur_cnode_ptr); UpdateStreamActive(cur_cnode_ptr); - } else if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == "StreamSwitch") { + } else if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) { // 3)update switch op MS_LOG(INFO) << "Insert active op after switch"; - CNodePtr active_ptr = KernelAdjust::GetInstance().CreateSteamActiveOp(graph_ptr); + CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr); update_cnode_list.emplace_back(cur_cnode_ptr); update_cnode_list.emplace_back(active_ptr); UpdateStreamSwitch(cur_cnode_ptr, active_ptr); @@ -417,6 +420,37 @@ void AscendStreamAssign::InsertActiveNew(const std::shared_ptr &graph_ptr) { + MS_LOG(INFO) << "start"; + MS_EXCEPTION_IF_NULL(graph_ptr); + CNodePtr cur_cnode_ptr = nullptr; + // key:virutal event id, value:real event id + std::unordered_map event_id_map; + uint32_t event_id; + auto cnode_ptr_list = graph_ptr->execution_order(); + for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { + cur_cnode_ptr = cnode_ptr_list[i]; + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kSendOpName || AnfAlgo::GetCNodeName(cur_cnode_ptr) == kRecvOpName) { + auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr); + MS_EXCEPTION_IF_NULL(primitive); + event_id = GetValue(primitive->GetAttr(kAttrEventId)); + // before stream assign, send/recv event_id assign from kFirstEventId + if (event_id < kFirstEventId) { + continue; + } + auto it = event_id_map.find(event_id); + if (it == event_id_map.end()) { + event_id_map.insert(std::make_pair(event_id, total_event_num_)); + AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(total_event_num_), cur_cnode_ptr); + total_event_num_++; + } else { + AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(it->second), cur_cnode_ptr); + } + } + } +} + void AscendStreamAssign::UpdateStreamId(const shared_ptr &graph_ptr) { MS_LOG(INFO) << "start"; MS_EXCEPTION_IF_NULL(graph_ptr); @@ -427,7 +461,7 @@ void AscendStreamAssign::UpdateStreamId(const shared_ptr & MS_EXCEPTION_IF_NULL(cur_cnode_ptr); uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); if (cur_stream_id < kIndependFirstStreamId) { - if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == "StreamActive") { + if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamActiveOpName) { auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr); MS_EXCEPTION_IF_NULL(primitive); vector active_ids = GetValue>(primitive->GetAttr(kAttrActiveStreamList)); @@ -471,6 +505,29 @@ void AscendStreamAssign::UpdateStreamId(const shared_ptr & MS_LOG(INFO) << "end"; } +void AscendStreamAssign::GetNeedActiveStreams(const shared_ptr &graph_ptr) { + MS_EXCEPTION_IF_NULL(graph_ptr); + CNodePtr cur_cnode_ptr = nullptr; + auto cnode_ptr_list = graph_ptr->execution_order(); + for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { + cur_cnode_ptr = cnode_ptr_list[i]; + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr); + MS_EXCEPTION_IF_NULL(primitive); + auto value_ptr = primitive->GetAttr(kStreamNeedActivedFirst); + if (value_ptr == nullptr) { + continue; + } + + auto need_active = GetValue(value_ptr); + if (need_active) { + auto stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); + MS_LOG(INFO) << "stream id:" << stream_id << " is need actived at first"; + need_first_active_streams_.push_back(stream_id); + } + } +} + void AscendStreamAssign::AssignStreamNew(const shared_ptr &graph_ptr) { if (IsTaskSink()) { ResetNew(); @@ -480,13 +537,15 @@ void AscendStreamAssign::AssignStreamNew(const shared_ptr InsertSendRecvForHcomParallel(graph_ptr); InsertSendRecvForIndependent(graph_ptr); UpdateStreamId(graph_ptr); + UpdateEventId(graph_ptr); + GetNeedActiveStreams(graph_ptr); MS_LOG(INFO) << "after finish stream assign"; PrintGraphExeOrders(graph_ptr); // Get info for D Model - generator::IRModelUtil::GetInstance().set_event_num(GetTotalEventNum()); - generator::IRModelUtil::GetInstance().set_stream_num(GetTotalCommonStreamNum() + GetTotalIndependStreamNum()); + generator::IRModelUtil::GetInstance().set_event_num(total_event_num()); + generator::IRModelUtil::GetInstance().set_stream_num(total_common_stream_num() + total_independ_stream_num()); // Init to 1,temporarily generator::IRModelUtil::GetInstance().set_batch_num(1); } @@ -495,7 +554,7 @@ void AscendStreamAssign::AssignStreamNew(const shared_ptr CNodePtr AscendStreamAssign::CreateSendApplyKernel(const std::shared_ptr &graph_ptr, uint32_t event_id, uint32_t stream_id) { MS_EXCEPTION_IF_NULL(graph_ptr); - auto send_op = std::make_shared("Send"); + auto send_op = std::make_shared(kSendOpName); MS_EXCEPTION_IF_NULL(send_op); auto send_apply = std::make_shared(send_op); MS_EXCEPTION_IF_NULL(send_apply); @@ -505,7 +564,7 @@ CNodePtr AscendStreamAssign::CreateSendApplyKernel(const std::shared_ptr(); MS_EXCEPTION_IF_NULL(abstract_none); send_node_ptr->set_abstract(abstract_none); @@ -516,7 +575,7 @@ CNodePtr AscendStreamAssign::CreateSendApplyKernel(const std::shared_ptr &graph_ptr, uint32_t event_id, uint32_t stream_id) { MS_EXCEPTION_IF_NULL(graph_ptr); - auto recv_op = std::make_shared("Recv"); + auto recv_op = std::make_shared(kRecvOpName); MS_EXCEPTION_IF_NULL(recv_op); auto recv_apply = std::make_shared(recv_op); MS_EXCEPTION_IF_NULL(recv_apply); @@ -526,7 +585,7 @@ CNodePtr AscendStreamAssign::CreateRecvApplyKernel(const std::shared_ptr(); MS_EXCEPTION_IF_NULL(abstract_none); @@ -605,7 +664,7 @@ bool AscendStreamAssign::IsIndependentNode(const CNodePtr &node_ptr) { return false; } - if (AnfAlgo::GetCNodeName(node_ptr) == "GetNext") { + if (AnfAlgo::GetCNodeName(node_ptr) == kGetNextOpName) { MS_LOG(INFO) << "GetNext should not be independent node"; return false; } @@ -638,20 +697,23 @@ bool AscendStreamAssign::IsTaskSink() { } } -std::vector AscendStreamAssign::GetWaitStreams() { - vector wait_active_stream_list; +void AscendStreamAssign::GetWaitStreams(vector *wait_active_stream_list) { if (total_common_stream_num_ == 0) { MS_LOG(INFO) << "total_common_stream_num is zero"; - return wait_active_stream_list; + return; } // common stream:active first common stream MS_LOG(INFO) << "active physic id[" << first_physic_id_ << "]"; for (uint32_t i = first_physic_id_ + 1; i < total_common_stream_num_; i++) { - MS_LOG(INFO) << "wait common stream id = " << i; - wait_active_stream_list.push_back(i); + auto it = std::find(need_first_active_streams_.begin(), need_first_active_streams_.end(), i); + if (it == need_first_active_streams_.end()) { + MS_LOG(INFO) << "wait common stream id = " << i; + (*wait_active_stream_list).push_back(i); + } } + // all independ stream id before first physical stream id should be actived auto it = logic_to_independent_map_.find(first_logic_id_); if (it != logic_to_independent_map_.end()) { uint32_t independent_id = it->second; @@ -675,16 +737,14 @@ std::vector AscendStreamAssign::GetWaitStreams() { if (i + total_common_stream_num_ <= max_before_physic) { continue; } - MS_LOG(INFO) << "wait independent stream id:" << i + total_common_stream_num_; - wait_active_stream_list.push_back(i + total_common_stream_num_); + // all wait streams should not in need_first_active_streams_ + auto iter = + std::find(need_first_active_streams_.begin(), need_first_active_streams_.end(), i + total_common_stream_num_); + if (iter == need_first_active_streams_.end()) { + MS_LOG(INFO) << "wait independent stream id:" << i + total_common_stream_num_; + (*wait_active_stream_list).push_back(i + total_common_stream_num_); + } } - - return wait_active_stream_list; -} - -std::vector AscendStreamAssign::GetHcomStreams() { - MS_LOG(INFO) << "hcom total stream nums:" << hcom_stream_list_.size(); - return hcom_stream_list_; } uint32_t AscendStreamAssign::GetTotalStreamNum() const { return total_common_stream_num_ + total_independ_stream_num_; } @@ -695,7 +755,7 @@ void AscendStreamAssign::PrintGraphExeOrders(const shared_ptr& graph_ptr); void AssignAllNodesStream(const std::shared_ptr& graph_ptr); void ResetNew(); void AssignStreamNew(const std::shared_ptr& graph_ptr); bool IsIndependentNode(const CNodePtr& node_ptr); - const std::unordered_map GetIndependentMap() { return logic_to_independent_map_; } - const std::unordered_map GetPhysicMap() { return logic_to_physic_map_; } - std::vector GetWaitStreams(); - std::vector GetHcomStreams(); - - private: - AscendStreamAssign() = default; - ~AscendStreamAssign() = default; - + const std::unordered_map& logic_to_independent_map() { return logic_to_independent_map_; } + const std::unordered_map& logic_to_physic_map() { return logic_to_physic_map_; } + const std::vector>& inner_parallel_streams() { return inner_parallel_streams_; } + void GetWaitStreams(vector* wait_active_stream_list); + const std::vector& hcom_streams() { return hcom_stream_list_; } CNodePtr CreateSendApplyKernel(const std::shared_ptr& graph_ptr, uint32_t event_id, uint32_t stream_id); CNodePtr CreateRecvApplyKernel(const std::shared_ptr& graph_ptr, uint32_t event_id, uint32_t stream_id); + private: + AscendStreamAssign() = default; + ~AscendStreamAssign() = default; + vector::iterator FindTargetOp(vector::iterator begin, vector::iterator end, const CNodePtr& node); bool IsHcom(const CNodePtr& apply_kernel); bool IsProcessed(uint32_t logic_id); - vector TransLogicToPhysic(const vector& logic_ids); + void TransLogicToPhysic(const vector& logic_ids, vector* physic_ids); void AssignCommonStreamId(const CNodePtr& cur_cnode_ptr, CNodePtr* pre_cnode_ptr, uint32_t* cur_index, uint32_t* cur_stream_id); void RecordIdMap(uint32_t logic_id, uint32_t physic_id); @@ -88,15 +86,17 @@ class AscendStreamAssign { bool IsTaskSink(); void AssignIndependentStreamId(const CNodePtr& cur_cnode_ptr, uint32_t deal_logic_id); void UpdateStreamId(const std::shared_ptr& graph_ptr); + void UpdateEventId(const std::shared_ptr& graph_ptr); void PrintGraphExeOrders(const std::shared_ptr& graph_ptr); void RecordFirstCommonOp(const CNodePtr& cur_cnode_ptr, uint32_t cur_node_logic_id, uint32_t cur_stream_id); uint32_t GetLogicId(const CNodePtr& cur_cnode_ptr); void SetCommonStreamNum(uint32_t cur_stream_id); void FindAllReduceParallel(const std::shared_ptr& graph_ptr); bool IsProcessedParallelStream(uint32_t stream_id); - vector GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id); + void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector* parallel_streams); void InsertSendRecvForIndependent(const std::shared_ptr& graph_ptr); void InsertSendRecvForHcomParallel(const std::shared_ptr& graph_ptr); + void GetNeedActiveStreams(const std::shared_ptr& graph_ptr); uint32_t total_common_stream_num_{0}; uint32_t total_independ_stream_num_{0}; @@ -112,6 +112,7 @@ class AscendStreamAssign { std::vector> inner_parallel_streams_{}; std::vector processed_parallel_streams_{}; std::vector hcom_stream_list_{}; + std::vector need_first_active_streams_{}; // new policy end }; } // namespace ascend diff --git a/mindspore/ccsrc/device/kernel_adjust.cc b/mindspore/ccsrc/device/kernel_adjust.cc index c1588d7d53f..b557436db94 100644 --- a/mindspore/ccsrc/device/kernel_adjust.cc +++ b/mindspore/ccsrc/device/kernel_adjust.cc @@ -32,16 +32,8 @@ #include "utils/utils.h" #include "device/ascend/profiling/profiling_manager.h" #include "device/ascend/kernel_select_ascend.h" -#include "device/kernel_info.h" #include "runtime/base.h" - -constexpr auto kLoopCountParamName = "loop_count"; -constexpr auto kIterLoopParamName = "iter_loop"; -constexpr auto kZeroParamName = "zero"; -constexpr auto kOneParamName = "one"; -constexpr auto kStreamSwitch = "StreamSwitch"; -constexpr auto kStreamActive = "StreamActive"; -constexpr auto kAssignAdd = "AssignAdd"; +#include "device/ascend/ascend_stream_assign.h" namespace mindspore { namespace device { using device::ascend::ProfilingUtils; @@ -70,6 +62,63 @@ bool KernelAdjust::NeedInsertSwitch() { ConfigManager::GetInstance().iter_num() > 1); } +uint32_t KernelAdjust::FindFirstStreamSwitchLabel(const std::shared_ptr &kernel_graph_ptr) { + MS_EXCEPTION_IF_NULL(kernel_graph_ptr); + auto cnode_ptr_list = kernel_graph_ptr->execution_order(); + CNodePtr cur_cnode_ptr = nullptr; + uint32_t label = kInvalidDistincLabel; + for (uint32_t i = 0; i < cnode_ptr_list.size(); ++i) { + cur_cnode_ptr = cnode_ptr_list[i]; + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) { + label = AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()); + break; + } + } + + return label; +} + +CNodePtr KernelAdjust::CreateSendApplyKernel(const std::shared_ptr &graph_ptr, + uint32_t event_id) { + MS_EXCEPTION_IF_NULL(graph_ptr); + auto send_op = std::make_shared(kSendOpName); + MS_EXCEPTION_IF_NULL(send_op); + auto send_apply = std::make_shared(send_op); + MS_EXCEPTION_IF_NULL(send_apply); + std::vector send_input_list = {send_apply}; + CNodePtr send_node_ptr = graph_ptr->NewCNode(send_input_list); + MS_EXCEPTION_IF_NULL(send_node_ptr); + kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; + selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), send_node_ptr.get()); + AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), send_node_ptr); + auto abstract_none = std::make_shared(); + MS_EXCEPTION_IF_NULL(abstract_none); + send_node_ptr->set_abstract(abstract_none); + return send_node_ptr; +} + +CNodePtr KernelAdjust::CreateRecvApplyKernel(const std::shared_ptr &graph_ptr, + uint32_t event_id) { + MS_EXCEPTION_IF_NULL(graph_ptr); + auto recv_op = std::make_shared(kRecvOpName); + MS_EXCEPTION_IF_NULL(recv_op); + auto recv_apply = std::make_shared(recv_op); + MS_EXCEPTION_IF_NULL(recv_apply); + std::vector recv_input_list = {recv_apply}; + CNodePtr recv_node_ptr = graph_ptr->NewCNode(recv_input_list); + MS_EXCEPTION_IF_NULL(recv_node_ptr); + kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; + selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), recv_node_ptr.get()); + AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), recv_node_ptr); + auto abstract_none = std::make_shared(); + MS_EXCEPTION_IF_NULL(abstract_none); + recv_node_ptr->set_abstract(abstract_none); + return recv_node_ptr; +} + void KernelAdjust::InsertSwitchLoop(const std::shared_ptr &kernel_graph_ptr) { if (!NeedInsertSwitch()) { return; @@ -93,21 +142,95 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr } } } - std::vector exec_order; - CNodePtr stream_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); - MS_EXCEPTION_IF_NULL(stream_switch_app); - exec_order.push_back(stream_switch_app); - CNodePtr stream_active_switch_app = CreateStreamActiveSwitchOp(kernel_graph_ptr); - MS_EXCEPTION_IF_NULL(stream_active_switch_app); + auto orders = kernel_graph_ptr->execution_order(); + if (orders.empty()) { + MS_LOG(EXCEPTION) << "graph execution order is empty"; + } + uint32_t first_cnode_stream_label = AnfAlgo::GetStreamDistinctionLabel(orders[0].get()); + + std::vector exec_order; + CNodePtr first_stream_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); + MS_EXCEPTION_IF_NULL(first_stream_switch_app); + AnfAlgo::SetStreamDistinctionLabel(kFirstStreamSwitchLabel, first_stream_switch_app.get()); + AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(kGetNextLabel), first_stream_switch_app); + + CNodePtr second_stream_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); + MS_EXCEPTION_IF_NULL(second_stream_switch_app); + AnfAlgo::SetStreamDistinctionLabel(kSecondStreamSwitchLabel, second_stream_switch_app.get()); + AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(first_cnode_stream_label), second_stream_switch_app); + // add attr "stream_need_active" + AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue(true), second_stream_switch_app); + + CNodePtr first_stream_active_app = CreateStreamActiveOp(kernel_graph_ptr); + MS_EXCEPTION_IF_NULL(first_stream_active_app); + AnfAlgo::SetStreamDistinctionLabel(first_cnode_stream_label, first_stream_active_app.get()); + std::vector first_active_streams = {kFirstStreamSwitchLabel}; + AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(first_active_streams), + first_stream_active_app); + + CNodePtr second_stream_active_app = CreateStreamActiveOp(kernel_graph_ptr); + MS_EXCEPTION_IF_NULL(second_stream_active_app); + // specific deal for common ctrl stream policy + uint32_t first_common_stream_switch_label = FindFirstStreamSwitchLabel(kernel_graph_ptr); + if (first_common_stream_switch_label == kInvalidDistincLabel) { + AnfAlgo::SetStreamDistinctionLabel(first_cnode_stream_label, second_stream_active_app.get()); + } else { + AnfAlgo::SetStreamDistinctionLabel(first_common_stream_switch_label, second_stream_active_app.get()); + } + + std::vector second_active_streams = {kSecondStreamSwitchLabel}; + AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(second_active_streams), + second_stream_active_app); CNodePtr assign_add_one = CreateStreamAssignAddnOP(kernel_graph_ptr, switch_loop_input); MS_EXCEPTION_IF_NULL(assign_add_one); + AnfAlgo::SetStreamDistinctionLabel(first_cnode_stream_label, assign_add_one.get()); + + CNodePtr send = CreateSendApplyKernel(kernel_graph_ptr, kFirstEventId); + AnfAlgo::SetStreamDistinctionLabel(kGetNextLabel, send.get()); + CNodePtr recv = CreateRecvApplyKernel(kernel_graph_ptr, kFirstEventId); + AnfAlgo::SetStreamDistinctionLabel(first_cnode_stream_label, recv.get()); + + // reorder graph orders + exec_order.push_back(first_stream_switch_app); + size_t i = 0; + for (; i < orders.size(); i++) { + auto node = orders[i]; + exec_order.push_back(node); + AnfAlgo::SetStreamDistinctionLabel(kGetNextLabel, exec_order[exec_order.size() - 1].get()); + if (AnfAlgo::GetCNodeName(node) == kGetNextOpName) { + break; + } + } + + exec_order.push_back(send); + exec_order.push_back(second_stream_switch_app); + exec_order.push_back(recv); exec_order.push_back(assign_add_one); - auto original_exec_order = kernel_graph_ptr->execution_order(); - (void)std::copy(original_exec_order.begin(), original_exec_order.end(), std::back_inserter(exec_order)); - exec_order.push_back(stream_active_switch_app); + std::vector memcpy_list; + std::vector before_list; + std::vector after_list; + bool first_memcpy_found = false; + CNodePtr cur_cnode = nullptr; + for (size_t idx = i + 1; idx < orders.size(); idx++) { + cur_cnode = orders[idx]; + if (AnfAlgo::HasNodeAttr(kAttrLabelForInsertStreamActive, cur_cnode)) { + memcpy_list.emplace_back(cur_cnode); + first_memcpy_found = true; + } else if (first_memcpy_found) { + after_list.emplace_back(cur_cnode); + } else { + before_list.emplace_back(cur_cnode); + } + } + + (void)std::copy(before_list.begin(), before_list.end(), std::back_inserter(exec_order)); + (void)std::copy(memcpy_list.begin(), memcpy_list.end(), std::back_inserter(exec_order)); + exec_order.push_back(first_stream_active_app); + (void)std::copy(after_list.begin(), after_list.end(), std::back_inserter(exec_order)); + exec_order.push_back(second_stream_active_app); kernel_graph_ptr->set_execution_order(exec_order); } @@ -167,7 +290,7 @@ CNodePtr KernelAdjust::CreateStreamSwitchOp(const std::shared_ptr(); - auto stream_switch = std::make_shared(kStreamSwitch); + auto stream_switch = std::make_shared(kStreamSwitchOpName); std::vector inputs; inputs.push_back(NewValueNode(stream_switch)); inputs.push_back(switch_loop_input.at(kLoopCountParamName)); @@ -181,28 +304,19 @@ CNodePtr KernelAdjust::CreateStreamSwitchOp(const std::shared_ptr(RT_LESS); ValuePtr cond = MakeValue(condition); AnfAlgo::SetNodeAttr(kAttrSwitchCondition, cond, stream_switch_app); - // set attr:true branch graph id ,which is same to stream distinction label - if (kernel_graph_ptr->execution_order().empty()) { - MS_LOG(EXCEPTION) << "empty execution order"; - } - auto first_node = kernel_graph_ptr->execution_order()[0]; - auto first_stream = AnfAlgo::GetStreamDistinctionLabel(first_node.get()); - AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(first_stream), stream_switch_app); // set attr:data_type int data_type = static_cast(RT_SWITCH_INT64); ValuePtr dt = MakeValue(data_type); AnfAlgo::SetNodeAttr(kAttrDataType, dt, stream_switch_app); // set distinction label and graph id - AnfAlgo::SetGraphId(kInvalidGraphId - 1, stream_switch_app.get()); - AnfAlgo::SetStreamDistinctionLabel(kInvalidDistincLabel - 1, stream_switch_app.get()); return stream_switch_app; } -CNodePtr KernelAdjust::CreateSteamActiveOp(const std::shared_ptr &kernel_graph_ptr) { +CNodePtr KernelAdjust::CreateStreamActiveOp(const std::shared_ptr &kernel_graph_ptr) { kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder( {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); abstract::AbstractBasePtr typeNone_abstract = std::make_shared(); - auto stream_active_others = std::make_shared(kStreamActive); + auto stream_active_others = std::make_shared(kStreamActiveOpName); std::vector inputs; inputs.push_back(NewValueNode(stream_active_others)); MS_EXCEPTION_IF_NULL(kernel_graph_ptr); @@ -213,57 +327,6 @@ CNodePtr KernelAdjust::CreateSteamActiveOp(const std::shared_ptr &kernel_graph_ptr) { - kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder( - {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); - abstract::AbstractBasePtr typeNone_abstract = std::make_shared(); - auto stream_active_switch = std::make_shared(kStreamActive); - std::vector inputs; - inputs.push_back(NewValueNode(stream_active_switch)); - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - CNodePtr stream_active_switch_app = kernel_graph_ptr->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(stream_active_switch_app); - AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), stream_active_switch_app.get()); - stream_active_switch_app->set_abstract(typeNone_abstract); - // set attr,which stream to active - std::vector active_index_value = {kInvalidDistincLabel - 1}; - auto value = MakeValue>(active_index_value); - AnfAlgo::SetNodeAttr(kAttrActiveStreamList, value, stream_active_switch_app); - // set the distinction label of stream active - if (kernel_graph_ptr->execution_order().empty()) { - MS_LOG(EXCEPTION) << "empty execution order"; - } - auto first_node = kernel_graph_ptr->execution_order()[0]; - auto label = AnfAlgo::GetStreamDistinctionLabel(first_node.get()); - // find the first switch's distinction label - for (auto node : kernel_graph_ptr->execution_order()) { - if (AnfAlgo::GetCNodeName(node) == "StreamSwitch") { - label = AnfAlgo::GetStreamDistinctionLabel(node.get()); - break; - } - } - AnfAlgo::SetStreamDistinctionLabel(label, stream_active_switch_app.get()); - return stream_active_switch_app; -} - -CNodePtr KernelAdjust::CreateStreamActiveOtherOp(const std::shared_ptr &kernel_graph_ptr) { - kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder( - {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); - abstract::AbstractBasePtr typeNone_abstract = std::make_shared(); - auto stream_active_others = std::make_shared(kStreamActive); - std::vector inputs; - inputs.push_back(NewValueNode(stream_active_others)); - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - CNodePtr stream_active_others_app = kernel_graph_ptr->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(stream_active_others_app); - AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), stream_active_others_app.get()); - stream_active_others_app->set_abstract(typeNone_abstract); - // set attr - ValuePtr active_target = MakeValue(kValueTargetOther); - AnfAlgo::SetNodeAttr(kAttrActiveTarget, active_target, stream_active_others_app); - return stream_active_others_app; -} - CNodePtr KernelAdjust::CreateStreamAssignAddnOP( const std::shared_ptr &kernel_graph_ptr, const std::map &switch_loop_input) { @@ -273,7 +336,7 @@ CNodePtr KernelAdjust::CreateStreamAssignAddnOP( selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT}); selected_kernel_builder.SetOutputsDeviceType({kNumberTypeInt32}); // AssignAdd - auto assign_add = std::make_shared(kAssignAdd); + auto assign_add = std::make_shared(kAssignAddOpName); std::vector inputs; inputs.push_back(NewValueNode(assign_add)); inputs.push_back(switch_loop_input.at(kLoopCountParamName)); @@ -290,70 +353,9 @@ CNodePtr KernelAdjust::CreateStreamAssignAddnOP( selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL); MS_EXCEPTION_IF_NULL(switch_loop_input.at(kLoopCountParamName)); assign_add_one->set_abstract(switch_loop_input.at(kLoopCountParamName)->abstract()); - // set the distinction label of assign add - if (kernel_graph_ptr->execution_order().empty()) { - MS_LOG(EXCEPTION) << "empty execution order"; - } - auto first_node = kernel_graph_ptr->execution_order()[0]; - auto label = AnfAlgo::GetStreamDistinctionLabel(first_node.get()); - AnfAlgo::SetStreamDistinctionLabel(label, assign_add_one.get()); return assign_add_one; } -void KernelAdjust::SetStreamActiveOPs(const std::shared_ptr &kernel_graph_ptr, - const std::unordered_set &ctrl_stream_list, - const std::unordered_set &comm_stream_list, - const std::unordered_set &momentum_stream_list) { - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - for (const auto &cnode_ptr : kernel_graph_ptr->execution_order()) { - MS_EXCEPTION_IF_NULL(cnode_ptr); - if (AnfAlgo::GetCNodeName(cnode_ptr) == kStreamActive) { - auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr); - ValuePtr active_target = primitive->GetAttr(kAttrActiveTarget); - std::vector index_list; - index_list.clear(); - if (GetValue(active_target) == kValueTargetSwitch) { - index_list.insert(index_list.end(), ctrl_stream_list.begin(), ctrl_stream_list.end()); - } else if (GetValue(active_target) == kValueTargetOther) { - for (uint32_t index : comm_stream_list) { - if (AnfAlgo::GetStreamId(cnode_ptr) == index) { - continue; - } - index_list.emplace_back(index); - } - index_list.insert(index_list.end(), momentum_stream_list.begin(), momentum_stream_list.end()); - } - ValuePtr index_list_value = MakeValue(index_list); - AnfAlgo::SetNodeAttr(kAttrActiveStreamList, index_list_value, cnode_ptr); - } - } -} - -void KernelAdjust::SetStreamSwitchOps(const std::shared_ptr &kernel_graph_ptr) { - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - CNodePtr switch_cnode_ptr = nullptr; - uint32_t target_stream_id = 0; - for (const auto &cnode_ptr : kernel_graph_ptr->execution_order()) { - MS_EXCEPTION_IF_NULL(cnode_ptr); - if (AnfAlgo::GetCNodeName(cnode_ptr) == kStreamSwitch) { - switch_cnode_ptr = cnode_ptr; - } - if (AnfAlgo::GetCNodeName(cnode_ptr) == kStreamActive) { - auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr); - ValuePtr active_target = primitive->GetAttr(kAttrActiveTarget); - if (GetValue(active_target) == kValueTargetOther) { - target_stream_id = AnfAlgo::GetStreamId(cnode_ptr); - } - } - } - if (switch_cnode_ptr != nullptr) { - // set attr:true stream - ValuePtr true_index = MakeValue(target_stream_id); - AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, true_index, switch_cnode_ptr); - MS_LOG(INFO) << "switch to true_index:" << target_stream_id; - } -} - bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr &context, const std::shared_ptr &kernel_graph_ptr) { if (!NeedInsertSwitch()) { diff --git a/mindspore/ccsrc/device/kernel_adjust.h b/mindspore/ccsrc/device/kernel_adjust.h index ca01d51e54b..3dced257c17 100644 --- a/mindspore/ccsrc/device/kernel_adjust.h +++ b/mindspore/ccsrc/device/kernel_adjust.h @@ -28,10 +28,22 @@ #include "session/session_context.h" #include "ir/meta_tensor.h" #include "device/ascend/profiling/profiling_utils.h" +#include "device/kernel_info.h" using mindspore::device::ascend::ProfilingTraceInfo; using mindspore::device::ascend::ProfilingUtils; namespace mindspore { +constexpr auto kLoopCountParamName = "loop_count"; +constexpr auto kIterLoopParamName = "iter_loop"; +constexpr auto kZeroParamName = "zero"; +constexpr auto kOneParamName = "one"; +constexpr auto kStreamNeedActivedFirst = "stream_need_active_first"; + +const uint32_t kFirstStreamSwitchLabel = kInvalidDistincLabel - 1; +const uint32_t kGetNextLabel = kInvalidDistincLabel - 2; +const uint32_t kSecondStreamSwitchLabel = kInvalidDistincLabel - 3; +const uint32_t kInvalidEventId = UINT32_MAX; +const uint32_t kFirstEventId = kInvalidEventId / 2; namespace device { class KernelAdjust { public: @@ -41,26 +53,23 @@ class KernelAdjust { } void Reorder(const std::shared_ptr &kernel_graph_ptr); void InsertSwitchLoop(const std::shared_ptr &kernel_graph_ptr); - void SetStreamActiveOPs(const std::shared_ptr &kernel_graph_ptr, - const std::unordered_set &ctrl_stream_list, - const std::unordered_set &comm_stream_list, - const std::unordered_set &momentum_stream_list); - void SetStreamSwitchOps(const std::shared_ptr &kernel_graph_ptr); bool StepLoadCtrlInputs(const std::shared_ptr &context, const std::shared_ptr &kernel_graph_ptr); void Profiling(NotNull kernel_graph_ptr); static bool NeedInsertSwitch(); - CNodePtr CreateSteamActiveOp(const std::shared_ptr &kernel_graph_ptr); + CNodePtr CreateStreamActiveOp(const std::shared_ptr &kernel_graph_ptr); private: KernelAdjust() = default; ~KernelAdjust() = default; + + CNodePtr CreateRecvApplyKernel(const std::shared_ptr &graph_ptr, uint32_t event_id); + CNodePtr CreateSendApplyKernel(const std::shared_ptr &graph_ptr, uint32_t event_id); + uint32_t FindFirstStreamSwitchLabel(const std::shared_ptr &kernel_graph_ptr); void CreateSwitchOpParameters(const std::shared_ptr &kernel_graph_ptr, std::map *switch_loop_input); CNodePtr CreateStreamSwitchOp(const std::shared_ptr &kernel_graph_ptr, const std::map &switch_loop_input); - CNodePtr CreateStreamActiveSwitchOp(const std::shared_ptr &kernel_graph_ptr); - CNodePtr CreateStreamActiveOtherOp(const std::shared_ptr &kernel_graph_ptr); CNodePtr CreateStreamAssignAddnOP(const std::shared_ptr &kernel_graph_ptr, const std::map &switch_loop_input); kernel::KernelBuildInfo::KernelBuildInfoBuilder CreateMngKernelBuilder(const std::vector &formats, diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc index 6c245d7548c..0de609f4413 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc @@ -62,6 +62,7 @@ #include "pre_activate/ascend/format_type/insert_transdata_for_runop.h" #include "pre_activate/ascend/enhancer/getnext_memcpy_elimination.h" #include "pre_activate/ascend/ir_fission/addn_fission.h" +#include "pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h" #include "utils/context/ms_context.h" #include "utils/config_manager.h" #include "debug/anf_ir_dump.h" @@ -187,6 +188,12 @@ void AscendBackendIRFusionOptimization(const std::shared_ptrAddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); } + + if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) { + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + } optimizer->AddPassManager(ir_fusion_pm); (void)optimizer->Optimize(kernel_graph); kernel_graph->SetExecOrderByDefault(); diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/stream_reuse.cc b/mindspore/ccsrc/pre_activate/mem_reuse/stream_reuse.cc index d1409cdedda..77f6f96cec1 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/stream_reuse.cc +++ b/mindspore/ccsrc/pre_activate/mem_reuse/stream_reuse.cc @@ -20,8 +20,8 @@ namespace mindspore { namespace memreuse { void StreamReuse::SetStreamReuseResource() { #ifdef ENABLE_D - auto logic_physic_map = device::ascend::AscendStreamAssign::GetInstance().GetPhysicMap(); - auto logic_independent_map = device::ascend::AscendStreamAssign::GetInstance().GetIndependentMap(); + auto logic_physic_map = device::ascend::AscendStreamAssign::GetInstance().logic_to_physic_map(); + auto logic_independent_map = device::ascend::AscendStreamAssign::GetInstance().logic_to_independent_map(); MS_LOG(INFO) << "stream mem reuse for Davici"; if (!logic_independent_map.empty() && !logic_physic_map.empty()) { set_logic_physic_map(logic_physic_map); diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc index ad6c58bc939..11ae3da6f7a 100755 --- a/mindspore/ccsrc/session/ascend_session.cc +++ b/mindspore/ccsrc/session/ascend_session.cc @@ -610,7 +610,7 @@ void AscendSession::CopyOutputOfIf(GraphId false_graph_id) { if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) { // insert active in true graph, another active will be inserted in kernel adjust - InsertStreamActiveToGraph(true_last_id, kInvalidDistincLabel - 1); + InsertStreamActiveToGraph(true_last_id, kSecondStreamSwitchLabel); } break; } diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index eac901b74de..eac1b862739 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -114,6 +114,9 @@ constexpr auto kFusedMulAddNOpName = "FusedMulAddN"; constexpr auto kFusedMulApplyMomentumOpName = "FusedMulApplyMomentum"; constexpr auto kBiasAddOpName = "BiasAdd"; constexpr auto kConfusionMulGradOpName = "ConfusionMulGrad"; +constexpr auto kStreamSwitchOpName = "StreamSwitch"; +constexpr auto kStreamActiveOpName = "StreamActive"; +constexpr auto kAssignAddOpName = "AssignAdd"; constexpr auto kSendOpName = "Send"; constexpr auto kRecvOpName = "Recv"; constexpr auto kReluV2OpName = "ReluV2"; diff --git a/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc b/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc index e0b5ab0d618..9c4fe2539db 100755 --- a/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc +++ b/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc @@ -24,9 +24,7 @@ void AscendStreamAssign::AssignStreamNew(const KernelGraphPtr &graph) { return; uint32_t AscendStreamAssign::GetTotalStreamNum() const { return 1; } -std::vector AscendStreamAssign::GetWaitStreams() { return vector(); } - -std::vector AscendStreamAssign::GetHcomStreams() { return vector(); } +void AscendStreamAssign::GetWaitStreams(vector *wait_active_stream_list) { return; } namespace tasksink { bool TaskGenerator::GenTasks(const std::vector &anf_node_list, std::vector *const task_info_list,