forked from mindspore-Ecosystem/mindspore
!24280 extract stream manager from stream assign
Merge pull request !24280 from laiyongqiang/master
This commit is contained in:
commit
e099bb52d5
|
@ -28,6 +28,7 @@
|
|||
#include "utils/mpi/mpi_config.h"
|
||||
#include "common/trans.h"
|
||||
#include "runtime/rt.h"
|
||||
#include "runtime/device/ascend/ascend_stream_manager.h"
|
||||
#include "runtime/device/ascend/ascend_stream_assign.h"
|
||||
#include "runtime/device/ascend/ge_runtime/model_runner.h"
|
||||
#include "runtime/device/ascend/tasksink/task_generator.h"
|
||||
|
@ -487,20 +488,19 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph &graph) {
|
|||
return true;
|
||||
}
|
||||
AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance();
|
||||
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
|
||||
AscendStreamMng &resource_manager = AscendStreamMng::GetInstance();
|
||||
// streams whose flag is not HEAD_STREAM
|
||||
std::vector<uint32_t> wait_active_stream_list;
|
||||
assign_instance.GetWaitStreams(&wait_active_stream_list);
|
||||
std::vector<uint32_t> force_copy_stream_list;
|
||||
assign_instance.GetHcomStreams(&force_copy_stream_list);
|
||||
MS_LOG(INFO) << "Call DavinciModel total stream num:" << resource_manager.get_cur_stream_num()
|
||||
<< ", total event num:" << resource_manager.get_cur_event_num()
|
||||
<< ", total label num:" << graph.label_num()
|
||||
MS_LOG(INFO) << "Call DavinciModel total stream num:" << resource_manager.cur_stream_num()
|
||||
<< ", total event num:" << resource_manager.cur_event_num() << ", total label num:" << graph.label_num()
|
||||
<< ", wait_active_stream_list size:" << wait_active_stream_list.size()
|
||||
<< ", force_copy_stream_list size:" << force_copy_stream_list.size();
|
||||
auto model = std::make_shared<ge::model_runner::DavinciModel>(
|
||||
task_info_list, wait_active_stream_list, force_copy_stream_list, 0, 0, 0, 0, 0, 0,
|
||||
resource_manager.get_cur_stream_num(), graph.label_num(), resource_manager.get_cur_event_num(), 0);
|
||||
resource_manager.cur_stream_num(), graph.label_num(), resource_manager.cur_event_num(), 0);
|
||||
auto ret = graph_model_map_.insert(std::make_pair(graph.graph_id(), model));
|
||||
if (!ret.second) {
|
||||
MS_LOG(EXCEPTION) << "Duplicate GraphId! Please check in ascend_session.";
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include "frontend/parallel/device_manager.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "runtime/device/kernel_adjust.h"
|
||||
#include "runtime/device/ascend/ascend_stream_manager.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "backend/kernel_compiler/oplib/oplib.h"
|
||||
#include "utils/utils.h"
|
||||
|
@ -589,7 +590,7 @@ void AscendStreamAssign::AssignAllNodesStream(const NotNull<KernelGraphPtr> &gra
|
|||
auto cnode_ptr_list = graph_ptr->execution_order();
|
||||
bool exit_independent = false;
|
||||
bool exit_hcom = false;
|
||||
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
|
||||
AscendStreamMng &resource_manager = AscendStreamMng::GetInstance();
|
||||
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
|
||||
CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
|
||||
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
||||
|
@ -611,19 +612,19 @@ void AscendStreamAssign::AssignAllNodesStream(const NotNull<KernelGraphPtr> &gra
|
|||
AssignCommonStreamId(cur_cnode_ptr);
|
||||
}
|
||||
|
||||
auto common_stream_num = resource_manager.get_cur_stream_num();
|
||||
auto common_stream_num = resource_manager.cur_stream_num();
|
||||
|
||||
if (exit_hcom) {
|
||||
AssignHcom(graph_ptr);
|
||||
}
|
||||
auto hcom_stream_num = resource_manager.get_cur_stream_num() - common_stream_num;
|
||||
auto hcom_stream_num = resource_manager.cur_stream_num() - common_stream_num;
|
||||
|
||||
if (exit_independent) {
|
||||
AssignIndependent(graph_ptr);
|
||||
}
|
||||
auto independent_stream_num = resource_manager.get_cur_stream_num() - common_stream_num - hcom_stream_num;
|
||||
auto independent_stream_num = resource_manager.cur_stream_num() - common_stream_num - hcom_stream_num;
|
||||
auto total_stream_num =
|
||||
resource_manager.get_cur_stream_num() + Uint32tMulWithOverflowCheck(hcom_stream_num, kHcomSecondaryStreamNum);
|
||||
resource_manager.cur_stream_num() + Uint32tMulWithOverflowCheck(hcom_stream_num, kHcomSecondaryStreamNum);
|
||||
MS_LOG(INFO) << "Total stream number: " << total_stream_num << ", common stream number: " << common_stream_num
|
||||
<< ", hcom stream number: " << hcom_stream_num << "*" << (kHcomSecondaryStreamNum + 1)
|
||||
<< ", independent stream number: " << independent_stream_num << ".";
|
||||
|
@ -633,14 +634,14 @@ void AscendStreamAssign::AssignAllNodesStream(const NotNull<KernelGraphPtr> &gra
|
|||
<< ", search details information in mindspore's FAQ.";
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "After stream assign, total stream nums:" << resource_manager.get_cur_stream_num();
|
||||
MS_LOG(INFO) << "After stream assign, total stream nums:" << resource_manager.cur_stream_num();
|
||||
}
|
||||
|
||||
void AscendStreamAssign::AssignCommonStreamId(const CNodePtr &cur_cnode_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
||||
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
|
||||
AscendStreamMng &resource_manager = AscendStreamMng::GetInstance();
|
||||
uint32_t cur_common_stream_id = 0;
|
||||
uint32_t cur_stream_num = resource_manager.get_cur_stream_num();
|
||||
uint32_t cur_stream_num = resource_manager.cur_stream_num();
|
||||
if (cur_stream_num == 0) {
|
||||
cur_common_stream_id = resource_manager.ApplyNewStream();
|
||||
} else {
|
||||
|
@ -717,7 +718,7 @@ void AscendStreamAssign::AssignHcom(const NotNull<KernelGraphPtr> &graph_ptr) {
|
|||
|
||||
uint32_t AscendStreamAssign::AssignHcomStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph) {
|
||||
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
||||
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
|
||||
AscendStreamMng &resource_manager = AscendStreamMng::GetInstance();
|
||||
auto task_num = GetHcomTaskNum(cur_cnode_ptr);
|
||||
|
||||
uint32_t cur_hcom_stream_id;
|
||||
|
@ -779,7 +780,7 @@ void AscendStreamAssign::AssignIndependent(const NotNull<KernelGraphPtr> &graph_
|
|||
|
||||
uint32_t AscendStreamAssign::AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph) {
|
||||
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
||||
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
|
||||
AscendStreamMng &resource_manager = AscendStreamMng::GetInstance();
|
||||
uint32_t cur_independent_stream_id;
|
||||
if (new_graph) {
|
||||
cur_independent_stream_id = resource_manager.ApplyNewStream();
|
||||
|
@ -1226,7 +1227,7 @@ void AscendStreamAssign::InsertEventForHcomParallel(const NotNull<KernelGraphPtr
|
|||
}
|
||||
|
||||
void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull<KernelGraphPtr> &graph_ptr) {
|
||||
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
|
||||
AscendStreamMng &resource_manager = AscendStreamMng::GetInstance();
|
||||
auto cnode_ptr_list = graph_ptr->execution_order();
|
||||
vector<CNodePtr> cnodes = cnode_ptr_list;
|
||||
uint32_t cur_event_id = resource_manager.ApplyNewEvent();
|
||||
|
@ -1265,12 +1266,12 @@ void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull<KernelGraphPt
|
|||
// one extra event was allocated, so delete it
|
||||
resource_manager.DeleteEvent();
|
||||
graph_ptr->set_execution_order(cnodes);
|
||||
MS_LOG(INFO) << "After common depend hcom, total event nums:" << resource_manager.get_cur_event_num();
|
||||
MS_LOG(INFO) << "After common depend hcom, total event nums:" << resource_manager.cur_event_num();
|
||||
}
|
||||
|
||||
// after memory reuse is correct, use this function
|
||||
void AscendStreamAssign::InsertEventHcomDependCommonBak(const NotNull<KernelGraphPtr> &graph_ptr) {
|
||||
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
|
||||
AscendStreamMng &resource_manager = AscendStreamMng::GetInstance();
|
||||
auto cnode_ptr_list = graph_ptr->execution_order();
|
||||
vector<CNodePtr> cnodes;
|
||||
CNodePtr cur_cnode_ptr = nullptr;
|
||||
|
@ -1318,7 +1319,7 @@ void AscendStreamAssign::InsertEventHcomDependCommonBak(const NotNull<KernelGrap
|
|||
}
|
||||
|
||||
graph_ptr->set_execution_order(cnodes);
|
||||
MS_LOG(INFO) << "After hcom depend common, total event nums:" << resource_manager.get_cur_event_num();
|
||||
MS_LOG(INFO) << "After hcom depend common, total event nums:" << resource_manager.cur_event_num();
|
||||
}
|
||||
|
||||
vector<CNodePtr> AscendStreamAssign::GetLastInputCnode(const NotNull<KernelGraphPtr> &graph_ptr,
|
||||
|
@ -1403,7 +1404,7 @@ vector<CNodePtr> AscendStreamAssign::GetInputKernels(const CNodePtr &cnode) {
|
|||
}
|
||||
|
||||
void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull<KernelGraphPtr> &graph_ptr) {
|
||||
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
|
||||
AscendStreamMng &resource_manager = AscendStreamMng::GetInstance();
|
||||
auto cnode_ptr_list = graph_ptr->execution_order();
|
||||
vector<CNodePtr> cnodes;
|
||||
CNodePtr cur_cnode_ptr = nullptr;
|
||||
|
@ -1444,7 +1445,7 @@ void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull<KernelGraphPt
|
|||
}
|
||||
|
||||
graph_ptr->set_execution_order(cnodes);
|
||||
MS_LOG(INFO) << "After hcom depend common, total event nums:" << resource_manager.get_cur_event_num();
|
||||
MS_LOG(INFO) << "After hcom depend common, total event nums:" << resource_manager.cur_event_num();
|
||||
}
|
||||
|
||||
void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr> &graph_ptr) {
|
||||
|
@ -1503,7 +1504,7 @@ void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr>
|
|||
void AscendStreamAssign::InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &graph_ptr,
|
||||
const std::vector<std::pair<uint32_t, vector<size_t>>> &hcom_index) {
|
||||
vector<CNodePtr> orders;
|
||||
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
|
||||
AscendStreamMng &resource_manager = AscendStreamMng::GetInstance();
|
||||
auto cnode_ptr_list = graph_ptr->execution_order();
|
||||
uint32_t cur_event_id = resource_manager.ApplyNewEvent();
|
||||
if (hcom_index.empty()) {
|
||||
|
@ -1591,7 +1592,7 @@ bool AscendStreamAssign::IsSatisfiedHcom(const std::vector<std::pair<uint32_t, v
|
|||
// section6
|
||||
void AscendStreamAssign::InsertEventForIndependentParallel(const NotNull<KernelGraphPtr> &graph_ptr) {
|
||||
MS_LOG(INFO) << "Start";
|
||||
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
|
||||
AscendStreamMng &resource_manager = AscendStreamMng::GetInstance();
|
||||
auto cnode_ptr_list = graph_ptr->execution_order();
|
||||
vector<CNodePtr> cnodes = cnode_ptr_list;
|
||||
uint32_t cur_event_id = resource_manager.ApplyNewEvent();
|
||||
|
@ -1647,7 +1648,7 @@ void AscendStreamAssign::InsertEventForIndependentParallel(const NotNull<KernelG
|
|||
}
|
||||
|
||||
graph_ptr->set_execution_order(new_cnodes);
|
||||
MS_LOG(INFO) << "After independent parallel, total event nums:" << resource_manager.get_cur_event_num();
|
||||
MS_LOG(INFO) << "After independent parallel, total event nums:" << resource_manager.cur_event_num();
|
||||
MS_LOG(INFO) << "End";
|
||||
}
|
||||
|
||||
|
@ -1834,7 +1835,7 @@ void AscendStreamAssign::CheckResourceAssign(const NotNull<KernelGraphPtr> &grap
|
|||
}
|
||||
|
||||
void AscendStreamAssign::CheckStreamAssign(const NotNull<KernelGraphPtr> &graph_ptr) {
|
||||
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
|
||||
AscendStreamMng &resource_manager = AscendStreamMng::GetInstance();
|
||||
std::set<uint32_t> streams;
|
||||
uint32_t max_stream = 0;
|
||||
uint32_t min_stream = kInvalidStreamId;
|
||||
|
@ -1861,7 +1862,7 @@ void AscendStreamAssign::CheckStreamAssign(const NotNull<KernelGraphPtr> &graph_
|
|||
if (min_stream != 0) {
|
||||
MS_LOG(EXCEPTION) << "Stream should start from 0, now is from " << min_stream;
|
||||
}
|
||||
uint32_t assigned_stream_num = resource_manager.get_cur_stream_num();
|
||||
uint32_t assigned_stream_num = resource_manager.cur_stream_num();
|
||||
if ((max_stream != assigned_stream_num - 1) || (streams.size() != assigned_stream_num)) {
|
||||
MS_LOG(EXCEPTION) << "Stream should be consecutive, max stream id:" << max_stream
|
||||
<< "; alloc stream nums:" << assigned_stream_num << "; streams size:" << streams.size();
|
||||
|
@ -1870,7 +1871,7 @@ void AscendStreamAssign::CheckStreamAssign(const NotNull<KernelGraphPtr> &graph_
|
|||
}
|
||||
|
||||
void AscendStreamAssign::CheckEventAssign(const NotNull<KernelGraphPtr> &graph_ptr) {
|
||||
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
|
||||
AscendStreamMng &resource_manager = AscendStreamMng::GetInstance();
|
||||
std::map<uint32_t, std::vector<CNodePtr>> event_map;
|
||||
uint32_t max_event_id = 0;
|
||||
uint32_t min_event_id = kInvalidEventId;
|
||||
|
@ -1901,7 +1902,7 @@ void AscendStreamAssign::CheckEventAssign(const NotNull<KernelGraphPtr> &graph_p
|
|||
if (min_event_id != 0) {
|
||||
MS_LOG(EXCEPTION) << "Event should start from 0, now is from " << min_event_id;
|
||||
}
|
||||
uint32_t assigned_event_num = resource_manager.get_cur_event_num();
|
||||
uint32_t assigned_event_num = resource_manager.cur_event_num();
|
||||
if ((max_event_id != assigned_event_num - 1) || (event_map.size() != assigned_event_num)) {
|
||||
MS_LOG(EXCEPTION) << "Event should be consecutive, however, assigned event num is: " << assigned_event_num
|
||||
<< ", max event id:" << max_event_id << ", event map is:" << event_map;
|
||||
|
@ -2024,8 +2025,8 @@ bool AscendStreamAssign::IsTaskSink() {
|
|||
|
||||
void AscendStreamAssign::GetWaitStreams(vector<uint32_t> *wait_active_stream_list) {
|
||||
MS_EXCEPTION_IF_NULL(wait_active_stream_list);
|
||||
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
|
||||
uint32_t total_stream_num = resource_manager.get_cur_stream_num();
|
||||
AscendStreamMng &resource_manager = AscendStreamMng::GetInstance();
|
||||
uint32_t total_stream_num = resource_manager.cur_stream_num();
|
||||
if (total_stream_num == 0) {
|
||||
MS_LOG(INFO) << "The total_common_stream_num is zero";
|
||||
return;
|
||||
|
@ -2134,8 +2135,8 @@ void AscendStreamAssign::GetStreamRelations() {
|
|||
}
|
||||
|
||||
void AscendStreamAssign::FindStreamRelations(const NotNull<KernelGraphPtr> &graph_ptr) {
|
||||
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
|
||||
auto stream_num = resource_manager.get_cur_stream_num();
|
||||
AscendStreamMng &resource_manager = AscendStreamMng::GetInstance();
|
||||
auto stream_num = resource_manager.cur_stream_num();
|
||||
if (stream_num <= 1) {
|
||||
return;
|
||||
}
|
||||
|
@ -2378,8 +2379,8 @@ bool AscendStreamAssign::IsSatisfiedEvent(uint32_t send_stream_id, uint32_t recv
|
|||
}
|
||||
|
||||
void AscendStreamAssign::FindEventRelations(const NotNull<KernelGraphPtr> &graph_ptr) {
|
||||
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
|
||||
auto event_nums = resource_manager.get_cur_event_num();
|
||||
AscendStreamMng &resource_manager = AscendStreamMng::GetInstance();
|
||||
auto event_nums = resource_manager.cur_event_num();
|
||||
if (event_nums == 0) {
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -45,59 +45,6 @@ using std::vector;
|
|||
using CNodeKey = void *;
|
||||
const uint32_t kInvalidStreamId = UINT32_MAX;
|
||||
const uint32_t kInvalidEventId = UINT32_MAX;
|
||||
class AscendResourceMng {
|
||||
public:
|
||||
static AscendResourceMng &GetInstance() {
|
||||
static AscendResourceMng instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void ResetResource() {
|
||||
cur_stream_num_ = 0;
|
||||
cur_event_num_ = 0;
|
||||
}
|
||||
uint32_t ApplyNewStream() {
|
||||
if (!cur_stream_num_) {
|
||||
uint32_t cur_stream_id = cur_stream_num_;
|
||||
cur_stream_num_++;
|
||||
return cur_stream_id;
|
||||
}
|
||||
uint32_t cur_stream_id = cur_stream_num_;
|
||||
cur_stream_num_++;
|
||||
return cur_stream_id;
|
||||
}
|
||||
uint32_t ApplyNewEvent() {
|
||||
if (!cur_event_num_) {
|
||||
uint32_t cur_event_id = cur_event_num_;
|
||||
cur_event_num_++;
|
||||
return cur_event_id;
|
||||
}
|
||||
uint32_t cur_event_id = cur_event_num_;
|
||||
cur_event_num_++;
|
||||
return cur_event_id;
|
||||
}
|
||||
|
||||
void DeleteEvent() {
|
||||
if (!cur_event_num_) {
|
||||
MS_LOG(WARNING) << "total event num is 0, no event to delete";
|
||||
} else {
|
||||
--cur_event_num_;
|
||||
}
|
||||
}
|
||||
uint32_t get_cur_stream_num() { return cur_stream_num_; }
|
||||
uint32_t GetCurAllocStreamId() {
|
||||
if (!cur_stream_num_) {
|
||||
MS_LOG(EXCEPTION) << "stream nums is 0, no stream id should be get";
|
||||
}
|
||||
return cur_stream_num_ - 1;
|
||||
}
|
||||
uint32_t get_cur_event_num() { return cur_event_num_; }
|
||||
|
||||
private:
|
||||
uint32_t cur_stream_num_{0};
|
||||
uint32_t cur_event_num_{0};
|
||||
};
|
||||
|
||||
enum StreamActiveKind { kInvalid = 0, kHead, kMiddle, kTail };
|
||||
class AscendStreamAssign {
|
||||
public:
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_STREAM_MANAGER_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_STREAM_MANAGER_H_
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace ascend {
|
||||
class AscendStreamMng {
|
||||
public:
|
||||
static AscendStreamMng &GetInstance() {
|
||||
static AscendStreamMng instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void ResetResource() {
|
||||
cur_stream_num_ = 0;
|
||||
cur_event_num_ = 0;
|
||||
}
|
||||
|
||||
uint32_t ApplyNewStream() { return cur_stream_num_++; }
|
||||
|
||||
uint32_t ApplyNewEvent() { return cur_event_num_++; }
|
||||
|
||||
void DeleteEvent() {
|
||||
if (!cur_event_num_) {
|
||||
MS_LOG(WARNING) << "total event num is 0, no event to delete";
|
||||
} else {
|
||||
--cur_event_num_;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t cur_stream_num() const { return cur_stream_num_; }
|
||||
|
||||
uint32_t GetCurAllocStreamId() {
|
||||
if (!cur_stream_num_) {
|
||||
MS_LOG(EXCEPTION) << "stream nums is 0, no stream id should be get";
|
||||
}
|
||||
return cur_stream_num_ - 1;
|
||||
}
|
||||
|
||||
uint32_t cur_event_num() const { return cur_event_num_; }
|
||||
|
||||
private:
|
||||
uint32_t cur_stream_num_{0};
|
||||
uint32_t cur_event_num_{0};
|
||||
};
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_STREAM_MANAGER_H_
|
|
@ -31,7 +31,7 @@
|
|||
#include "utils/utils.h"
|
||||
#include "runtime/device/ascend/profiling/profiling_manager.h"
|
||||
#include "runtime/base.h"
|
||||
#include "runtime/device/ascend/ascend_stream_assign.h"
|
||||
#include "runtime/device/ascend/ascend_stream_manager.h"
|
||||
#include "utils/shape_utils.h"
|
||||
|
||||
namespace {
|
||||
|
@ -139,7 +139,7 @@ void KernelAdjust::InsertIndepentParallel(const std::shared_ptr<session::KernelG
|
|||
std::vector<CNodePtr> *exec_order) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
|
||||
MS_EXCEPTION_IF_NULL(exec_order);
|
||||
device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance();
|
||||
device::ascend::AscendStreamMng &resource_manager = device::ascend::AscendStreamMng::GetInstance();
|
||||
CNodePtr independent_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kIndependentStreamSwitch);
|
||||
MS_EXCEPTION_IF_NULL(independent_switch_app);
|
||||
uint32_t independent_switch_stream_id = resource_manager.ApplyNewStream();
|
||||
|
@ -158,7 +158,7 @@ void KernelAdjust::InsertFpBpLoopStreamSwitch(const std::shared_ptr<session::Ker
|
|||
MS_EXCEPTION_IF_NULL(exec_order);
|
||||
MS_EXCEPTION_IF_NULL(fpbp_stream_id);
|
||||
MS_EXCEPTION_IF_NULL(fpbp_switch_stream_id);
|
||||
device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance();
|
||||
device::ascend::AscendStreamMng &resource_manager = device::ascend::AscendStreamMng::GetInstance();
|
||||
*fpbp_switch_stream_id = resource_manager.ApplyNewStream();
|
||||
*fpbp_stream_id = resource_manager.ApplyNewStream();
|
||||
CNodePtr fpbp_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kFpBpStreamSwitch);
|
||||
|
@ -296,7 +296,7 @@ void KernelAdjust::InsertGetNextLoopStreamSwitch(
|
|||
MS_EXCEPTION_IF_NULL(exec_order);
|
||||
MS_EXCEPTION_IF_NULL(getnext_switch_stream_id);
|
||||
MS_EXCEPTION_IF_NULL(getnext_stream_id);
|
||||
device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance();
|
||||
device::ascend::AscendStreamMng &resource_manager = device::ascend::AscendStreamMng::GetInstance();
|
||||
*getnext_switch_stream_id = resource_manager.ApplyNewStream();
|
||||
*getnext_stream_id = resource_manager.ApplyNewStream();
|
||||
CNodePtr getnext_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kGetNextStreamSwitch);
|
||||
|
@ -331,7 +331,7 @@ void KernelAdjust::InsertGetNextLoopFpBpStartSend(const std::shared_ptr<session:
|
|||
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
|
||||
MS_EXCEPTION_IF_NULL(exec_order);
|
||||
MS_EXCEPTION_IF_NULL(fpbp_start_event_id);
|
||||
device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance();
|
||||
device::ascend::AscendStreamMng &resource_manager = device::ascend::AscendStreamMng::GetInstance();
|
||||
*fpbp_start_event_id = resource_manager.ApplyNewEvent();
|
||||
CNodePtr fpbp_start_send = CreateSendApplyKernel(kernel_graph_ptr, *fpbp_start_event_id);
|
||||
AnfAlgo::SetStreamId(getnext_stream_id, fpbp_start_send.get());
|
||||
|
@ -345,7 +345,7 @@ void KernelAdjust::InsertGetNextLoopEosStartSend(const std::shared_ptr<session::
|
|||
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
|
||||
MS_EXCEPTION_IF_NULL(exec_order);
|
||||
MS_EXCEPTION_IF_NULL(eos_start_event_id);
|
||||
device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance();
|
||||
device::ascend::AscendStreamMng &resource_manager = device::ascend::AscendStreamMng::GetInstance();
|
||||
*eos_start_event_id = resource_manager.ApplyNewEvent();
|
||||
CNodePtr eos_start_send = CreateSendApplyKernel(kernel_graph_ptr, *eos_start_event_id);
|
||||
AnfAlgo::SetStreamId(getnext_stream_id, eos_start_send.get());
|
||||
|
@ -361,7 +361,7 @@ void KernelAdjust::InsertEosStreamSwitch(const std::shared_ptr<session::KernelGr
|
|||
MS_EXCEPTION_IF_NULL(exec_order);
|
||||
MS_EXCEPTION_IF_NULL(eos_switch_stream_id);
|
||||
MS_EXCEPTION_IF_NULL(eos_stream_id);
|
||||
device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance();
|
||||
device::ascend::AscendStreamMng &resource_manager = device::ascend::AscendStreamMng::GetInstance();
|
||||
*eos_switch_stream_id = resource_manager.ApplyNewStream();
|
||||
*eos_stream_id = resource_manager.ApplyNewStream();
|
||||
CNodePtr eos_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kEosStreamSwitch);
|
||||
|
@ -405,7 +405,7 @@ void KernelAdjust::InsertEosDoneSend(const std::shared_ptr<session::KernelGraph>
|
|||
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
|
||||
MS_EXCEPTION_IF_NULL(exec_order);
|
||||
MS_EXCEPTION_IF_NULL(eos_done_event_id);
|
||||
device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance();
|
||||
device::ascend::AscendStreamMng &resource_manager = device::ascend::AscendStreamMng::GetInstance();
|
||||
*eos_done_event_id = resource_manager.ApplyNewEvent();
|
||||
CNodePtr eos_done_send = CreateSendApplyKernel(kernel_graph_ptr, *eos_done_event_id);
|
||||
AnfAlgo::SetStreamId(eos_stream_id, eos_done_send.get());
|
||||
|
@ -415,7 +415,7 @@ void KernelAdjust::InsertEosDoneSend(const std::shared_ptr<session::KernelGraph>
|
|||
|
||||
void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
|
||||
device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance();
|
||||
device::ascend::AscendStreamMng &resource_manager = device::ascend::AscendStreamMng::GetInstance();
|
||||
resource_manager.ResetResource();
|
||||
if (!NeedInsertSwitch()) {
|
||||
return;
|
||||
|
|
Loading…
Reference in New Issue