diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc index 1804a23e94f..55d0f285b1e 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc @@ -28,6 +28,7 @@ #include "utils/mpi/mpi_config.h" #include "common/trans.h" #include "runtime/rt.h" +#include "runtime/device/ascend/ascend_stream_manager.h" #include "runtime/device/ascend/ascend_stream_assign.h" #include "runtime/device/ascend/ge_runtime/model_runner.h" #include "runtime/device/ascend/tasksink/task_generator.h" @@ -487,20 +488,19 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph &graph) { return true; } AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance(); - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + AscendStreamMng &resource_manager = AscendStreamMng::GetInstance(); // the streams' flag not HEAD_STREAM std::vector wait_active_stream_list; assign_instance.GetWaitStreams(&wait_active_stream_list); std::vector force_copy_stream_list; assign_instance.GetHcomStreams(&force_copy_stream_list); - MS_LOG(INFO) << "Call DavinciModel total stream num:" << resource_manager.get_cur_stream_num() - << ", total event num:" << resource_manager.get_cur_event_num() - << ", total label num:" << graph.label_num() + MS_LOG(INFO) << "Call DavinciModel total stream num:" << resource_manager.cur_stream_num() + << ", total event num:" << resource_manager.cur_event_num() << ", total label num:" << graph.label_num() << ", wait_active_stream_list size:" << wait_active_stream_list.size() << ", force_copy_stream_list size:" << force_copy_stream_list.size(); auto model = std::make_shared( task_info_list, wait_active_stream_list, force_copy_stream_list, 0, 0, 0, 0, 0, 0, - resource_manager.get_cur_stream_num(), graph.label_num(), resource_manager.get_cur_event_num(), 0); + resource_manager.cur_stream_num(), graph.label_num(), resource_manager.cur_event_num(), 0); auto ret = graph_model_map_.insert(std::make_pair(graph.graph_id(), model)); if (!ret.second) { MS_LOG(EXCEPTION) << "Duplicate GraphId! Please check in ascend_session."; diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc index b54e9677871..6f96afc6d83 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc @@ -26,6 +26,7 @@ #include "frontend/parallel/device_manager.h" #include "backend/session/anf_runtime_algorithm.h" #include "runtime/device/kernel_adjust.h" +#include "runtime/device/ascend/ascend_stream_manager.h" #include "backend/optimizer/common/helper.h" #include "backend/kernel_compiler/oplib/oplib.h" #include "utils/utils.h" @@ -589,7 +590,7 @@ void AscendStreamAssign::AssignAllNodesStream(const NotNull &gra auto cnode_ptr_list = graph_ptr->execution_order(); bool exit_independent = false; bool exit_hcom = false; - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + AscendStreamMng &resource_manager = AscendStreamMng::GetInstance(); for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; MS_EXCEPTION_IF_NULL(cur_cnode_ptr); @@ -611,19 +612,19 @@ void AscendStreamAssign::AssignAllNodesStream(const NotNull &gra AssignCommonStreamId(cur_cnode_ptr); } - auto common_stream_num = resource_manager.get_cur_stream_num(); + auto common_stream_num = resource_manager.cur_stream_num(); if (exit_hcom) { AssignHcom(graph_ptr); } - auto hcom_stream_num = resource_manager.get_cur_stream_num() - common_stream_num; + auto hcom_stream_num = resource_manager.cur_stream_num() - common_stream_num; if (exit_independent) { AssignIndependent(graph_ptr); } - auto independent_stream_num = resource_manager.get_cur_stream_num() - common_stream_num - hcom_stream_num; + auto independent_stream_num = resource_manager.cur_stream_num() - common_stream_num - hcom_stream_num; auto total_stream_num = - resource_manager.get_cur_stream_num() + Uint32tMulWithOverflowCheck(hcom_stream_num, kHcomSecondaryStreamNum); + resource_manager.cur_stream_num() + Uint32tMulWithOverflowCheck(hcom_stream_num, kHcomSecondaryStreamNum); MS_LOG(INFO) << "Total stream number: " << total_stream_num << ", common stream number: " << common_stream_num << ", hcom stream number: " << hcom_stream_num << "*" << (kHcomSecondaryStreamNum + 1) << ", independent stream number: " << independent_stream_num << "."; @@ -633,14 +634,14 @@ void AscendStreamAssign::AssignAllNodesStream(const NotNull &gra << ", search details information in mindspore's FAQ."; } - MS_LOG(INFO) << "After stream assign, total stream nums:" << resource_manager.get_cur_stream_num(); + MS_LOG(INFO) << "After stream assign, total stream nums:" << resource_manager.cur_stream_num(); } void AscendStreamAssign::AssignCommonStreamId(const CNodePtr &cur_cnode_ptr) { MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + AscendStreamMng &resource_manager = AscendStreamMng::GetInstance(); uint32_t cur_common_stream_id = 0; - uint32_t cur_stream_num = resource_manager.get_cur_stream_num(); + uint32_t cur_stream_num = resource_manager.cur_stream_num(); if (cur_stream_num == 0) { cur_common_stream_id = resource_manager.ApplyNewStream(); } else { @@ -717,7 +718,7 @@ void AscendStreamAssign::AssignHcom(const NotNull &graph_ptr) { uint32_t AscendStreamAssign::AssignHcomStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph) { MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + AscendStreamMng &resource_manager = AscendStreamMng::GetInstance(); auto task_num = GetHcomTaskNum(cur_cnode_ptr); uint32_t cur_hcom_stream_id; @@ -779,7 +780,7 @@ void AscendStreamAssign::AssignIndependent(const NotNull &graph_ uint32_t AscendStreamAssign::AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph) { MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + AscendStreamMng &resource_manager = AscendStreamMng::GetInstance(); uint32_t cur_independent_stream_id; if (new_graph) { cur_independent_stream_id = resource_manager.ApplyNewStream(); @@ -1226,7 +1227,7 @@ void AscendStreamAssign::InsertEventForHcomParallel(const NotNull &graph_ptr) { - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + AscendStreamMng &resource_manager = AscendStreamMng::GetInstance(); auto cnode_ptr_list = graph_ptr->execution_order(); vector cnodes = cnode_ptr_list; uint32_t cur_event_id = resource_manager.ApplyNewEvent(); @@ -1265,12 +1266,12 @@ void AscendStreamAssign::InsertEventCommonDependHcom(const NotNullset_execution_order(cnodes); - MS_LOG(INFO) << "After common depend hcom, total event nums:" << resource_manager.get_cur_event_num(); + MS_LOG(INFO) << "After common depend hcom, total event nums:" << resource_manager.cur_event_num(); } // after memory reuse is correct, use this function void AscendStreamAssign::InsertEventHcomDependCommonBak(const NotNull &graph_ptr) { - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + AscendStreamMng &resource_manager = AscendStreamMng::GetInstance(); auto cnode_ptr_list = graph_ptr->execution_order(); vector cnodes; CNodePtr cur_cnode_ptr = nullptr; @@ -1318,7 +1319,7 @@ void AscendStreamAssign::InsertEventHcomDependCommonBak(const NotNullset_execution_order(cnodes); - MS_LOG(INFO) << "After hcom depend common, total event nums:" << resource_manager.get_cur_event_num(); + MS_LOG(INFO) << "After hcom depend common, total event nums:" << resource_manager.cur_event_num(); } vector AscendStreamAssign::GetLastInputCnode(const NotNull &graph_ptr, @@ -1403,7 +1404,7 @@ vector AscendStreamAssign::GetInputKernels(const CNodePtr &cnode) { } void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull &graph_ptr) { - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + AscendStreamMng &resource_manager = AscendStreamMng::GetInstance(); auto cnode_ptr_list = graph_ptr->execution_order(); vector cnodes; CNodePtr cur_cnode_ptr = nullptr; @@ -1444,7 +1445,7 @@ void AscendStreamAssign::InsertEventHcomDependCommon(const NotNullset_execution_order(cnodes); - MS_LOG(INFO) << "After hcom depend common, total event nums:" << resource_manager.get_cur_event_num(); + MS_LOG(INFO) << "After hcom depend common, total event nums:" << resource_manager.cur_event_num(); } void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull &graph_ptr) { @@ -1503,7 +1504,7 @@ void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull void AscendStreamAssign::InsertEventBetweenHcom(const NotNull &graph_ptr, const std::vector>> &hcom_index) { vector orders; - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + AscendStreamMng &resource_manager = AscendStreamMng::GetInstance(); auto cnode_ptr_list = graph_ptr->execution_order(); uint32_t cur_event_id = resource_manager.ApplyNewEvent(); if (hcom_index.empty()) { @@ -1591,7 +1592,7 @@ bool AscendStreamAssign::IsSatisfiedHcom(const std::vector &graph_ptr) { MS_LOG(INFO) << "Start"; - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + AscendStreamMng &resource_manager = AscendStreamMng::GetInstance(); auto cnode_ptr_list = graph_ptr->execution_order(); vector cnodes = cnode_ptr_list; uint32_t cur_event_id = resource_manager.ApplyNewEvent(); @@ -1647,7 +1648,7 @@ void AscendStreamAssign::InsertEventForIndependentParallel(const NotNullset_execution_order(new_cnodes); - MS_LOG(INFO) << "After independent parallel, total event nums:" << resource_manager.get_cur_event_num(); + MS_LOG(INFO) << "After independent parallel, total event nums:" << resource_manager.cur_event_num(); MS_LOG(INFO) << "End"; } @@ -1834,7 +1835,7 @@ void AscendStreamAssign::CheckResourceAssign(const NotNull &grap } void AscendStreamAssign::CheckStreamAssign(const NotNull &graph_ptr) { - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + AscendStreamMng &resource_manager = AscendStreamMng::GetInstance(); std::set streams; uint32_t max_stream = 0; uint32_t min_stream = kInvalidStreamId; @@ -1861,7 +1862,7 @@ void AscendStreamAssign::CheckStreamAssign(const NotNull &graph_ if (min_stream != 0) { MS_LOG(EXCEPTION) << "Stream should start from 0, now is from " << min_stream; } - uint32_t assigned_stream_num = resource_manager.get_cur_stream_num(); + uint32_t assigned_stream_num = resource_manager.cur_stream_num(); if ((max_stream != assigned_stream_num - 1) || (streams.size() != assigned_stream_num)) { MS_LOG(EXCEPTION) << "Stream should be consecutive, max stream id:" << max_stream << "; alloc stream nums:" << assigned_stream_num << "; streams size:" << streams.size(); @@ -1870,7 +1871,7 @@ void AscendStreamAssign::CheckStreamAssign(const NotNull &graph_ } void AscendStreamAssign::CheckEventAssign(const NotNull &graph_ptr) { - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + AscendStreamMng &resource_manager = AscendStreamMng::GetInstance(); std::map> event_map; uint32_t max_event_id = 0; uint32_t min_event_id = kInvalidEventId; @@ -1901,7 +1902,7 @@ void AscendStreamAssign::CheckEventAssign(const NotNull &graph_p if (min_event_id != 0) { MS_LOG(EXCEPTION) << "Event should start from 0, now is from " << min_event_id; } - uint32_t assigned_event_num = resource_manager.get_cur_event_num(); + uint32_t assigned_event_num = resource_manager.cur_event_num(); if ((max_event_id != assigned_event_num - 1) || (event_map.size() != assigned_event_num)) { MS_LOG(EXCEPTION) << "Event should be consecutive, however, assigned event num is: " << assigned_event_num << ", max event id:" << max_event_id << ", event map is:" << event_map; @@ -2024,8 +2025,8 @@ bool AscendStreamAssign::IsTaskSink() { void AscendStreamAssign::GetWaitStreams(vector *wait_active_stream_list) { MS_EXCEPTION_IF_NULL(wait_active_stream_list); - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - uint32_t total_stream_num = resource_manager.get_cur_stream_num(); + AscendStreamMng &resource_manager = AscendStreamMng::GetInstance(); + uint32_t total_stream_num = resource_manager.cur_stream_num(); if (total_stream_num == 0) { MS_LOG(INFO) << "The total_common_stream_num is zero"; return; @@ -2134,8 +2135,8 @@ void AscendStreamAssign::GetStreamRelations() { } void AscendStreamAssign::FindStreamRelations(const NotNull &graph_ptr) { - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - auto stream_num = resource_manager.get_cur_stream_num(); + AscendStreamMng &resource_manager = AscendStreamMng::GetInstance(); + auto stream_num = resource_manager.cur_stream_num(); if (stream_num <= 1) { return; } @@ -2378,8 +2379,8 @@ bool AscendStreamAssign::IsSatisfiedEvent(uint32_t send_stream_id, uint32_t recv } void AscendStreamAssign::FindEventRelations(const NotNull &graph_ptr) { - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - auto event_nums = resource_manager.get_cur_event_num(); + AscendStreamMng &resource_manager = AscendStreamMng::GetInstance(); + auto event_nums = resource_manager.cur_event_num(); if (event_nums == 0) { return; } diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h index 17d58d07602..f7ac5ab8adf 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h @@ -45,59 +45,6 @@ using std::vector; using CNodeKey = void *; const uint32_t kInvalidStreamId = UINT32_MAX; const uint32_t kInvalidEventId = UINT32_MAX; -class AscendResourceMng { - public: - static AscendResourceMng &GetInstance() { - static AscendResourceMng instance; - return instance; - } - - void ResetResource() { - cur_stream_num_ = 0; - cur_event_num_ = 0; - } - uint32_t ApplyNewStream() { - if (!cur_stream_num_) { - uint32_t cur_stream_id = cur_stream_num_; - cur_stream_num_++; - return cur_stream_id; - } - uint32_t cur_stream_id = cur_stream_num_; - cur_stream_num_++; - return cur_stream_id; - } - uint32_t ApplyNewEvent() { - if (!cur_event_num_) { - uint32_t cur_event_id = cur_event_num_; - cur_event_num_++; - return cur_event_id; - } - uint32_t cur_event_id = cur_event_num_; - cur_event_num_++; - return cur_event_id; - } - - void DeleteEvent() { - if (!cur_event_num_) { - MS_LOG(WARNING) << "total event num is 0, no event to delete"; - } else { - --cur_event_num_; - } - } - uint32_t get_cur_stream_num() { return cur_stream_num_; } - uint32_t GetCurAllocStreamId() { - if (!cur_stream_num_) { - MS_LOG(EXCEPTION) << "stream nums is 0, no stream id should be get"; - } - return cur_stream_num_ - 1; - } - uint32_t get_cur_event_num() { return cur_event_num_; } - - private: - uint32_t cur_stream_num_{0}; - uint32_t cur_event_num_{0}; -}; - enum StreamActiveKind { kInvalid = 0, kHead, kMiddle, kTail }; class AscendStreamAssign { public: diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_manager.h b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_manager.h new file mode 100644 index 00000000000..7f8710ebdd2 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_manager.h @@ -0,0 +1,66 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_STREAM_MANAGER_H_ +#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_STREAM_MANAGER_H_ + +namespace mindspore { +namespace device { +namespace ascend { +class AscendStreamMng { + public: + static AscendStreamMng &GetInstance() { + static AscendStreamMng instance; + return instance; + } + + void ResetResource() { + cur_stream_num_ = 0; + cur_event_num_ = 0; + } + + uint32_t ApplyNewStream() { return cur_stream_num_++; } + + uint32_t ApplyNewEvent() { return cur_event_num_++; } + + void DeleteEvent() { + if (!cur_event_num_) { + MS_LOG(WARNING) << "total event num is 0, no event to delete"; + } else { + --cur_event_num_; + } + } + + uint32_t cur_stream_num() const { return cur_stream_num_; } + + uint32_t GetCurAllocStreamId() { + if (!cur_stream_num_) { + MS_LOG(EXCEPTION) << "stream nums is 0, no stream id should be get"; + } + return cur_stream_num_ - 1; + } + + uint32_t cur_event_num() const { return cur_event_num_; } + + private: + uint32_t cur_stream_num_{0}; + uint32_t cur_event_num_{0}; +}; +} // namespace ascend +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_STREAM_MANAGER_H_ diff --git a/mindspore/ccsrc/runtime/device/kernel_adjust.cc b/mindspore/ccsrc/runtime/device/kernel_adjust.cc index 27f87f9fd0f..a8c24f7d3eb 100644 --- a/mindspore/ccsrc/runtime/device/kernel_adjust.cc +++ b/mindspore/ccsrc/runtime/device/kernel_adjust.cc @@ -31,7 +31,7 @@ #include "utils/utils.h" #include "runtime/device/ascend/profiling/profiling_manager.h" #include "runtime/base.h" -#include "runtime/device/ascend/ascend_stream_assign.h" +#include "runtime/device/ascend/ascend_stream_manager.h" #include "utils/shape_utils.h" namespace { @@ -139,7 +139,7 @@ void KernelAdjust::InsertIndepentParallel(const std::shared_ptr *exec_order) { MS_EXCEPTION_IF_NULL(kernel_graph_ptr); MS_EXCEPTION_IF_NULL(exec_order); - device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance(); + device::ascend::AscendStreamMng &resource_manager = device::ascend::AscendStreamMng::GetInstance(); CNodePtr independent_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kIndependentStreamSwitch); MS_EXCEPTION_IF_NULL(independent_switch_app); uint32_t independent_switch_stream_id = resource_manager.ApplyNewStream(); @@ -158,7 +158,7 @@ void KernelAdjust::InsertFpBpLoopStreamSwitch(const std::shared_ptr MS_EXCEPTION_IF_NULL(kernel_graph_ptr); MS_EXCEPTION_IF_NULL(exec_order); MS_EXCEPTION_IF_NULL(eos_done_event_id); - device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance(); + device::ascend::AscendStreamMng &resource_manager = device::ascend::AscendStreamMng::GetInstance(); *eos_done_event_id = resource_manager.ApplyNewEvent(); CNodePtr eos_done_send = CreateSendApplyKernel(kernel_graph_ptr, *eos_done_event_id); AnfAlgo::SetStreamId(eos_stream_id, eos_done_send.get()); @@ -415,7 +415,7 @@ void KernelAdjust::InsertEosDoneSend(const std::shared_ptr void KernelAdjust::InsertSwitchLoop(const std::shared_ptr &kernel_graph_ptr) { MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance(); + device::ascend::AscendStreamMng &resource_manager = device::ascend::AscendStreamMng::GetInstance(); resource_manager.ResetResource(); if (!NeedInsertSwitch()) { return;