forked from mindspore-Ecosystem/mindspore
!2056 Enable new control sink
Merge pull request !2056 from zhoufeng/enable-new-control-sink
This commit is contained in:
commit
53654f94f2
|
@ -340,15 +340,17 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
|
|||
return true;
|
||||
}
|
||||
|
||||
AscendStreamAssign &stream_assign_instance = AscendStreamAssign::GetInstance();
|
||||
AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance();
|
||||
AscendStreamMng &stream_manager = AscendStreamMng::GetInstance();
|
||||
AscendLabelAssign &label_assign_instance = AscendLabelAssign::GetInstance();
|
||||
// the streams' flag not HEAD_STREAM
|
||||
std::vector<uint32_t> wait_active_stream_list;
|
||||
stream_assign_instance.GetWaitStreams(&wait_active_stream_list);
|
||||
auto force_copy_stream_list = stream_assign_instance.hcom_streams();
|
||||
assign_instance.GetWaitStreams(&wait_active_stream_list);
|
||||
std::vector<uint32_t> force_copy_stream_list;
|
||||
assign_instance.GetHcomStreams(&force_copy_stream_list);
|
||||
|
||||
MS_LOG(INFO) << "call DavinciModel total stream num:" << stream_assign_instance.GetTotalStreamNum()
|
||||
<< ", total event num:" << stream_assign_instance.total_event_num()
|
||||
MS_LOG(INFO) << "call DavinciModel total stream num:" << stream_manager.GetCurAllocStreamNum()
|
||||
<< ", total event num:" << assign_instance.total_event_num()
|
||||
<< ", total label num:" << label_assign_instance.GetLabelNum(NOT_NULL(graph))
|
||||
<< ", wait_active_stream_list size:" << wait_active_stream_list.size()
|
||||
<< ", force_copy_stream_list size:" << force_copy_stream_list.size();
|
||||
|
@ -356,8 +358,8 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
|
|||
std::vector<std::shared_ptr<ge::model_runner::OpInfo>> empty_list;
|
||||
std::shared_ptr<ge::model_runner::DavinciModel> model = std::make_shared<ge::model_runner::DavinciModel>(
|
||||
task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0,
|
||||
0, 0, 0, 0, 0, stream_assign_instance.GetTotalStreamNum(), label_assign_instance.GetLabelNum(NOT_NULL(graph)),
|
||||
stream_assign_instance.total_event_num(), 0);
|
||||
0, 0, 0, 0, 0, stream_manager.GetCurAllocStreamNum(), label_assign_instance.GetLabelNum(NOT_NULL(graph)),
|
||||
assign_instance.total_event_num(), 0);
|
||||
|
||||
auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model));
|
||||
if (!ret.second) {
|
||||
|
|
|
@ -33,238 +33,220 @@ namespace device {
|
|||
namespace ascend {
|
||||
// Threshold used by AssignCommonStreamId for HCCL kernels: a new stream is started
// whenever the running task index is a multiple of this value.
constexpr uint32_t kHcomMaxTask = 5;
// Same threshold, applied to kernels that are not HCCL kernels.
constexpr uint32_t kCommonMaxTask = 350;
// Initial value of independent_id_; stream ids at or above it are treated as
// independent streams by InsertActiveNew.
constexpr uint32_t kIndependFirstStreamId = 1024;
|
||||
|
||||
bool AscendStreamAssign::IsHcom(const CNodePtr &apply_kernel) {
|
||||
MS_EXCEPTION_IF_NULL(apply_kernel);
|
||||
return AnfAlgo::GetKernelType(apply_kernel) == HCCL_KERNEL;
|
||||
void AscendStreamAssign::AssignStream(const shared_ptr<session::KernelGraph> &graph_ptr) {
|
||||
if (IsTaskSink()) {
|
||||
Reset();
|
||||
ReorderIndependentOrders(graph_ptr);
|
||||
AssignAllNodesStream(graph_ptr);
|
||||
UpdateAtomicAddrCleanStreamId(graph_ptr);
|
||||
FindHcomParallelStreams(graph_ptr);
|
||||
InsertStreamActive(graph_ptr);
|
||||
InsertSendRecvForHcomParallel(graph_ptr);
|
||||
InsertSendRecvForIndependent(graph_ptr);
|
||||
UpdateEventId(graph_ptr);
|
||||
GetNeedActiveStreams(graph_ptr);
|
||||
graph_ptr->PrintGraphExecuteOrder();
|
||||
CheckStreamAssign(graph_ptr);
|
||||
MS_LOG(INFO) << "after finish stream assign";
|
||||
|
||||
// Get info for D Model
|
||||
AscendStreamMng &stream_manager = AscendStreamMng::GetInstance();
|
||||
generator::IRModelUtil::GetInstance().set_event_num(total_event_num());
|
||||
generator::IRModelUtil::GetInstance().set_stream_num(stream_manager.GetCurAllocStreamNum());
|
||||
// Init to 1,temporarily
|
||||
generator::IRModelUtil::GetInstance().set_batch_num(1);
|
||||
}
|
||||
}
|
||||
|
||||
// Puts every bookkeeping member touched here back to its initial value:
// counters to zero, "first id" markers to UINT32_MAX, the independent id
// cursor to kIndependFirstStreamId, and all containers to empty.
void AscendStreamAssign::ResetNew() {
  // Counters.
  total_event_num_ = 0;
  total_common_stream_num_ = 0;
  total_independ_stream_num_ = 0;

  // Id markers / cursors.
  first_logic_id_ = UINT32_MAX;
  first_physic_id_ = UINT32_MAX;
  independent_id_ = kIndependFirstStreamId;

  // Lookup tables.
  logic_to_physic_map_.clear();
  logic_to_independent_map_.clear();

  // Per-run lists.
  processed_logic_id_.clear();
  independent_before_physic_id_.clear();
  inner_parallel_streams_.clear();
  processed_parallel_streams_.clear();
  hcom_stream_list_.clear();
  need_first_active_streams_.clear();
}
|
||||
// section 0
|
||||
void AscendStreamAssign::CheckStreamAssign(const shared_ptr<session::KernelGraph> &graph_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(graph_ptr);
|
||||
std::set<uint32_t> streams;
|
||||
uint32_t max_stream = 0;
|
||||
uint32_t min_stream = kInvalidStreamId;
|
||||
const std::vector<CNodePtr> &cnode_ptr_list = graph_ptr->execution_order();
|
||||
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
|
||||
CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
|
||||
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
||||
uint32_t stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
|
||||
if (stream_id == kInvalidStreamId) {
|
||||
MS_LOG(EXCEPTION) << "node [" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "] had not been assigned streams";
|
||||
}
|
||||
|
||||
void AscendStreamAssign::AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, uint32_t processing_logic_id) {
|
||||
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
||||
auto it = logic_to_independent_map_.find(processing_logic_id);
|
||||
if (it == logic_to_independent_map_.end()) {
|
||||
(void)logic_to_independent_map_.insert(std::make_pair(processing_logic_id, independent_id_));
|
||||
AnfAlgo::SetStreamId(independent_id_, cur_cnode_ptr.get());
|
||||
independent_id_++;
|
||||
} else {
|
||||
AnfAlgo::SetStreamId(it->second, cur_cnode_ptr.get());
|
||||
streams.emplace(stream_id);
|
||||
if (stream_id > max_stream) {
|
||||
max_stream = stream_id;
|
||||
}
|
||||
if (stream_id < min_stream) {
|
||||
min_stream = stream_id;
|
||||
}
|
||||
}
|
||||
|
||||
if (first_physic_id_ == UINT32_MAX) {
|
||||
auto res = std::find(independent_before_physic_id_.begin(), independent_before_physic_id_.end(),
|
||||
AnfAlgo::GetStreamId(cur_cnode_ptr));
|
||||
if (res == independent_before_physic_id_.end()) {
|
||||
independent_before_physic_id_.push_back(AnfAlgo::GetStreamId(cur_cnode_ptr));
|
||||
if (!streams.empty()) {
|
||||
if (min_stream != 0) {
|
||||
MS_LOG(EXCEPTION) << "before stream assign, assigned stream should start from 0, now is from " << min_stream;
|
||||
}
|
||||
if (max_stream != (streams.size() - 1)) {
|
||||
MS_LOG(EXCEPTION) << "before stream assign, assigned stream should be consecutive";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void AscendStreamAssign::AssignCommonStreamId(const CNodePtr &cur_cnode_ptr, CNodePtr *pre_cnode_ptr,
|
||||
uint32_t *cur_index, uint32_t *cur_stream_id) {
|
||||
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
||||
MS_EXCEPTION_IF_NULL(*pre_cnode_ptr);
|
||||
bool over_max_hcom_task = (IsHcom(cur_cnode_ptr) && (*cur_index) % kHcomMaxTask == 0);
|
||||
bool over_max_common_task = (!IsHcom(cur_cnode_ptr) && (*cur_index) % kCommonMaxTask == 0);
|
||||
bool pre_common_cur_hcom = (IsHcom(cur_cnode_ptr) && !IsHcom(*pre_cnode_ptr));
|
||||
bool pre_hcom_cur_common = (!IsHcom(cur_cnode_ptr) && IsHcom(*pre_cnode_ptr));
|
||||
if (over_max_hcom_task || over_max_common_task || pre_common_cur_hcom || pre_hcom_cur_common) {
|
||||
*cur_index = 0;
|
||||
++(*cur_stream_id);
|
||||
}
|
||||
|
||||
if (over_max_hcom_task || pre_common_cur_hcom) {
|
||||
hcom_stream_list_.emplace_back(*cur_stream_id);
|
||||
}
|
||||
++(*cur_index);
|
||||
AnfAlgo::SetStreamId(*cur_stream_id, cur_cnode_ptr.get());
|
||||
*pre_cnode_ptr = cur_cnode_ptr;
|
||||
}
|
||||
|
||||
bool AscendStreamAssign::IsProcessed(uint32_t logic_id) {
|
||||
auto it = std::find(processed_logic_id_.begin(), processed_logic_id_.end(), logic_id);
|
||||
if (it == processed_logic_id_.end()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void AscendStreamAssign::RecordIdMap(uint32_t logic_id, uint32_t physic_id) {
|
||||
auto it = logic_to_physic_map_.find(logic_id);
|
||||
if (it == logic_to_physic_map_.end()) {
|
||||
MS_LOG(INFO) << "New logic_id[" << logic_id << "] to physic_id[" << physic_id << "]";
|
||||
(void)logic_to_physic_map_.insert(std::make_pair(logic_id, physic_id));
|
||||
}
|
||||
}
|
||||
|
||||
// Handles the first common node: puts it on `cur_stream_id`, records the
// logic-id -> physic-id association, and stores both ids as the "first" ids.
void AscendStreamAssign::RecordFirstCommonOp(const CNodePtr &cur_cnode_ptr, uint32_t cur_node_logic_id,
                                             uint32_t cur_stream_id) {
  AnfAlgo::SetStreamId(cur_stream_id, cur_cnode_ptr.get());
  RecordIdMap(cur_node_logic_id, cur_stream_id);

  // Remember where the common streams start.
  first_logic_id_ = cur_node_logic_id;
  first_physic_id_ = cur_stream_id;
}
|
||||
|
||||
uint32_t AscendStreamAssign::GetLogicId(const CNodePtr &cur_cnode_ptr) {
|
||||
uint32_t logic_id = AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get());
|
||||
if (logic_id == kInvalidDistincLabel) {
|
||||
MS_LOG(EXCEPTION) << "node[" << cur_cnode_ptr->DebugString() << "] logic id is invalid";
|
||||
}
|
||||
return logic_id;
|
||||
}
|
||||
|
||||
void AscendStreamAssign::SetCommonStreamNum(uint32_t cur_stream_id) {
|
||||
if (first_physic_id_ == UINT32_MAX) {
|
||||
MS_LOG(INFO) << "cur common node size is zero";
|
||||
total_common_stream_num_ = 0;
|
||||
} else {
|
||||
total_common_stream_num_ = cur_stream_id + 1;
|
||||
}
|
||||
}
|
||||
|
||||
// section 1
|
||||
void AscendStreamAssign::AssignAllNodesStream(const shared_ptr<session::KernelGraph> &graph_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(graph_ptr);
|
||||
auto cnode_ptr_list = graph_ptr->execution_order();
|
||||
CNodePtr pre_cnode_ptr = nullptr;
|
||||
uint32_t cur_index = 0;
|
||||
uint32_t cur_stream_id = 0;
|
||||
uint32_t processing_logic_id = UINT32_MAX;
|
||||
|
||||
bool exit_independent = false;
|
||||
AscendStreamMng &stream_manager = AscendStreamMng::GetInstance();
|
||||
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
|
||||
CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
|
||||
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
||||
// get logic id
|
||||
uint32_t cur_node_logic_id = GetLogicId(cur_cnode_ptr);
|
||||
if (IsIndependentNode(cur_cnode_ptr)) {
|
||||
AssignIndependentStreamId(cur_cnode_ptr, cur_node_logic_id);
|
||||
if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) {
|
||||
continue;
|
||||
}
|
||||
if (IsIndependentNode(cur_cnode_ptr)) {
|
||||
exit_independent = true;
|
||||
continue;
|
||||
}
|
||||
// first common node, only exe one time
|
||||
if (pre_cnode_ptr == nullptr) {
|
||||
RecordFirstCommonOp(cur_cnode_ptr, cur_node_logic_id, cur_stream_id);
|
||||
processing_logic_id = cur_node_logic_id;
|
||||
uint32_t cur_stream_num = stream_manager.GetCurAllocStreamNum();
|
||||
if (cur_stream_num == 0) {
|
||||
cur_stream_id = stream_manager.ApplyNewStream();
|
||||
} else {
|
||||
cur_stream_id = stream_manager.GetCurAllocStream();
|
||||
}
|
||||
++cur_index;
|
||||
pre_cnode_ptr = cur_cnode_ptr;
|
||||
continue;
|
||||
}
|
||||
|
||||
// 1.has been processed
|
||||
if (IsProcessed(cur_node_logic_id)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (cur_node_logic_id == processing_logic_id) {
|
||||
AssignCommonStreamId(cur_cnode_ptr, &pre_cnode_ptr, &cur_index, &cur_stream_id);
|
||||
} else {
|
||||
// 1.find other same logic id
|
||||
for (size_t j = i; j < cnode_ptr_list.size(); ++j) {
|
||||
CNodePtr cnode_ptr = cnode_ptr_list[j];
|
||||
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
||||
uint32_t logic_id = AnfAlgo::GetStreamDistinctionLabel(cnode_ptr.get());
|
||||
if (logic_id == processing_logic_id) {
|
||||
AssignCommonStreamId(cnode_ptr, &pre_cnode_ptr, &cur_index, &cur_stream_id);
|
||||
}
|
||||
}
|
||||
// 2.after deal:
|
||||
processed_logic_id_.push_back(processing_logic_id);
|
||||
cur_cnode_ptr = cnode_ptr_list[i];
|
||||
// 3. new stream
|
||||
++cur_stream_id;
|
||||
AnfAlgo::SetStreamId(cur_stream_id, cur_cnode_ptr.get());
|
||||
cur_index = 1;
|
||||
|
||||
pre_cnode_ptr = cur_cnode_ptr;
|
||||
processing_logic_id = cur_node_logic_id;
|
||||
RecordIdMap(processing_logic_id, cur_stream_id);
|
||||
if (IsHcom(cur_cnode_ptr)) {
|
||||
hcom_stream_list_.emplace(cur_stream_id);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
AssignCommonStreamId(cur_cnode_ptr, &pre_cnode_ptr, &cur_index, &cur_stream_id);
|
||||
}
|
||||
|
||||
SetCommonStreamNum(cur_stream_id);
|
||||
total_independ_stream_num_ = independent_id_ - kIndependFirstStreamId;
|
||||
MS_LOG(INFO) << "stream nums:common:" << total_common_stream_num_ << ",independ:" << total_independ_stream_num_;
|
||||
if (exit_independent) {
|
||||
uint32_t first_independent_stream_id = stream_manager.ApplyNewStream();
|
||||
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
|
||||
CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
|
||||
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
||||
if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) {
|
||||
continue;
|
||||
}
|
||||
if (IsIndependentNode(cur_cnode_ptr)) {
|
||||
AssignIndependentStreamId(cur_cnode_ptr);
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "independent start from :" << first_independent_stream_id;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "total stream nums:" << stream_manager.GetCurAllocStreamNum();
|
||||
}
|
||||
|
||||
void AscendStreamAssign::TransLogicToPhysic(const vector<uint32_t> &logic_ids, vector<uint32_t> *physic_ids) {
|
||||
for (auto &id : logic_ids) {
|
||||
auto it = logic_to_physic_map_.find(id);
|
||||
if (it != logic_to_physic_map_.end()) {
|
||||
MS_LOG(INFO) << "logic id[" << id << "] to physic id[" << it->second << "]";
|
||||
(*physic_ids).push_back(it->second);
|
||||
void AscendStreamAssign::AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
||||
AscendStreamMng &stream_manager = AscendStreamMng::GetInstance();
|
||||
uint32_t cur_independent_id = stream_manager.GetCurAllocStream();
|
||||
auto it = independent_stream_map_.find(cur_independent_id);
|
||||
if (it == independent_stream_map_.end()) {
|
||||
AnfAlgo::SetStreamId(cur_independent_id, cur_cnode_ptr.get());
|
||||
independent_stream_map_.emplace(cur_independent_id, 1);
|
||||
} else {
|
||||
if (it->second < kCommonMaxTask) {
|
||||
AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get());
|
||||
it->second++;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "logic id[" << id << "] has no correspond physic id";
|
||||
}
|
||||
|
||||
auto it_independ = logic_to_independent_map_.find(id);
|
||||
if (it_independ != logic_to_independent_map_.end()) {
|
||||
MS_LOG(INFO) << "logic id[" << id << "] to independent id[" << it_independ->second << "]";
|
||||
(*physic_ids).push_back(it_independ->second);
|
||||
cur_independent_id = stream_manager.ApplyNewStream();
|
||||
AnfAlgo::SetStreamId(cur_independent_id, cur_cnode_ptr.get());
|
||||
independent_stream_map_.emplace(cur_independent_id, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Rewrites the kAttrActiveStreamList attribute of a StreamActive node so that it holds
// the ids produced by TransLogicToPhysic instead of logic ids.
//
// Params:
//   active_ptr  the StreamActive node to update; must not be null.
// Raises an exception (MS_EXCEPTION_IF_NULL) if `active_ptr` or its primitive is null.
void AscendStreamAssign::UpdateStreamActive(const CNodePtr &active_ptr) {
  // The null check has to come before the log line, which dereferences active_ptr.
  MS_EXCEPTION_IF_NULL(active_ptr);
  MS_LOG(INFO) << "start update outter active op[" << active_ptr->DebugString() << "] ";
  auto primitive = AnfAlgo::GetCNodePrimitive(active_ptr);
  MS_EXCEPTION_IF_NULL(primitive);
  vector<uint32_t> active_logic_ids = GetValue<std::vector<uint32_t>>(primitive->GetAttr(kAttrActiveStreamList));
  // The streams activated by an outer StreamActive are not parallel for now; if they
  // become parallel, that has to be handled here.
  vector<uint32_t> active_physic_ids;
  TransLogicToPhysic(active_logic_ids, &active_physic_ids);
  ValuePtr active_physic_value = MakeValue<std::vector<uint32_t>>(active_physic_ids);
  AnfAlgo::SetNodeAttr(kAttrActiveStreamList, active_physic_value, active_ptr);
}
|
||||
|
||||
void AscendStreamAssign::UpdateStreamSwitch(const CNodePtr &switch_ptr, const CNodePtr &active_ptr) {
|
||||
MS_LOG(INFO) << "start update switch op[" << switch_ptr->DebugString() << "]";
|
||||
MS_EXCEPTION_IF_NULL(switch_ptr);
|
||||
MS_EXCEPTION_IF_NULL(active_ptr);
|
||||
auto primitive = AnfAlgo::GetCNodePrimitive(switch_ptr);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto true_logic_id = GetValue<uint32_t>(primitive->GetAttr(kAttrTrueBranchStream));
|
||||
MS_LOG(INFO) << "streamswtich stream id[" << AnfAlgo::GetStreamId(switch_ptr) << "], true_logic_id[" << true_logic_id
|
||||
<< "]";
|
||||
vector<uint32_t> logic_ids{true_logic_id};
|
||||
vector<uint32_t> physic_ids;
|
||||
TransLogicToPhysic(logic_ids, &physic_ids);
|
||||
if (physic_ids.empty()) {
|
||||
MS_LOG(EXCEPTION) << "stream switch true logic id[" << true_logic_id << "] has no physical id";
|
||||
bool AscendStreamAssign::IsIndependentNode(const CNodePtr &node_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(node_ptr);
|
||||
if (AnfAlgo::GetKernelType(node_ptr) != AICPU_KERNEL) {
|
||||
return false;
|
||||
}
|
||||
ValuePtr true_index = MakeValue<uint32_t>(physic_ids[0]);
|
||||
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, true_index, switch_ptr);
|
||||
|
||||
MS_LOG(INFO) << "start update StreamActive op[" << active_ptr->DebugString() << "]";
|
||||
AnfAlgo::SetStreamId(physic_ids[0], active_ptr.get());
|
||||
vector<uint32_t> active_ids;
|
||||
for (size_t i = 0; i < physic_ids.size(); i++) {
|
||||
if (i == 0) {
|
||||
MS_LOG(INFO) << "StreamActive op self stream id[" << physic_ids[i] << "]";
|
||||
} else {
|
||||
MS_LOG(INFO) << "StreamActive op active stream id[" << physic_ids[i] << "]";
|
||||
active_ids.emplace_back(physic_ids[i]);
|
||||
if (AnfAlgo::GetCNodeName(node_ptr) == kGetNextOpName) {
|
||||
MS_LOG(INFO) << "GetNext should not be independent node";
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t input_nums = AnfAlgo::GetInputTensorNum(node_ptr);
|
||||
if (input_nums == 0) {
|
||||
MS_LOG(INFO) << "node " << node_ptr->fullname_with_scope() << " is independent, as inputs nums is zero";
|
||||
return true;
|
||||
}
|
||||
|
||||
const std::vector<AnfNodePtr> &inputs = node_ptr->inputs();
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
if (!inputs[i]->isa<ValueNode>()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_ids), active_ptr);
|
||||
MS_LOG(INFO) << "node " << node_ptr->fullname_with_scope() << " is independent, as inputs is all value node";
|
||||
return true;
|
||||
}
|
||||
|
||||
void AscendStreamAssign::FindAllReduceParallel(const shared_ptr<session::KernelGraph> &graph_ptr) {
|
||||
void AscendStreamAssign::AssignCommonStreamId(const CNodePtr &cur_cnode_ptr, CNodePtr *pre_cnode_ptr,
|
||||
uint32_t *cur_index, uint32_t *cur_stream_id) {
|
||||
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
||||
MS_EXCEPTION_IF_NULL(pre_cnode_ptr);
|
||||
MS_EXCEPTION_IF_NULL(*pre_cnode_ptr);
|
||||
AscendStreamMng &stream_manager = AscendStreamMng::GetInstance();
|
||||
bool over_max_hcom_task = (IsHcom(cur_cnode_ptr) && (*cur_index) % kHcomMaxTask == 0);
|
||||
bool over_max_common_task = (!IsHcom(cur_cnode_ptr) && (*cur_index) % kCommonMaxTask == 0);
|
||||
bool pre_common_cur_hcom = (IsHcom(cur_cnode_ptr) && !IsHcom(*pre_cnode_ptr));
|
||||
bool pre_hcom_cur_common = (!IsHcom(cur_cnode_ptr) && IsHcom(*pre_cnode_ptr));
|
||||
if (over_max_hcom_task || over_max_common_task || pre_common_cur_hcom || pre_hcom_cur_common) {
|
||||
*cur_index = 0;
|
||||
*cur_stream_id = stream_manager.ApplyNewStream();
|
||||
}
|
||||
|
||||
++(*cur_index);
|
||||
AnfAlgo::SetStreamId(*cur_stream_id, cur_cnode_ptr.get());
|
||||
*pre_cnode_ptr = cur_cnode_ptr;
|
||||
|
||||
// record ll hcom streams as hcom stream has different stream flag
|
||||
if (IsHcom(cur_cnode_ptr)) {
|
||||
auto it = std::find(hcom_stream_list_.begin(), hcom_stream_list_.end(), *cur_stream_id);
|
||||
if (it == hcom_stream_list_.end()) {
|
||||
MS_LOG(INFO) << "hcom stream id:" << *cur_stream_id;
|
||||
hcom_stream_list_.emplace(*cur_stream_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// section 2:
|
||||
void AscendStreamAssign::UpdateAtomicAddrCleanStreamId(const shared_ptr<session::KernelGraph> &graph_ptr) {
|
||||
MS_LOG(INFO) << "start";
|
||||
MS_EXCEPTION_IF_NULL(graph_ptr);
|
||||
const std::vector<CNodePtr> &cnode_ptr_list = graph_ptr->execution_order();
|
||||
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
|
||||
CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
|
||||
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
||||
// update AtomicAddrClean stream same witch the next node
|
||||
if (i > 0 && AnfAlgo::GetCNodeName(cnode_ptr_list[i - 1]) == kAtomicAddrCleanOpName) {
|
||||
MS_LOG(INFO) << "update AtomicAddrClean stream id from[" << AnfAlgo::GetStreamId(cnode_ptr_list[i - 1])
|
||||
<< "] to [" << AnfAlgo::GetStreamId(cur_cnode_ptr) << "]";
|
||||
AnfAlgo::SetStreamId(AnfAlgo::GetStreamId(cur_cnode_ptr), cnode_ptr_list[i - 1].get());
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "end";
|
||||
}
|
||||
|
||||
// section 3
|
||||
void AscendStreamAssign::FindHcomParallelStreams(const shared_ptr<session::KernelGraph> &graph_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(graph_ptr);
|
||||
CNodePtr cur_cnode_ptr = nullptr;
|
||||
CNodePtr pre_cnode_ptr = nullptr;
|
||||
|
@ -280,9 +262,9 @@ void AscendStreamAssign::FindAllReduceParallel(const shared_ptr<session::KernelG
|
|||
continue;
|
||||
}
|
||||
|
||||
bool diff_stream = (pre_stream_id != cur_stream_id) && (pre_stream_id < cur_stream_id);
|
||||
bool pre_hcom = IsHcom(pre_cnode_ptr);
|
||||
if (diff_stream && pre_hcom) {
|
||||
bool pre_fusion_hcom = IsFusionHcom(pre_cnode_ptr);
|
||||
bool diff_stream = (pre_stream_id != cur_stream_id);
|
||||
if (diff_stream && pre_fusion_hcom) {
|
||||
inner_parallel_streams_.emplace_back(std::vector<uint32_t>{pre_stream_id, cur_stream_id});
|
||||
}
|
||||
|
||||
|
@ -291,6 +273,138 @@ void AscendStreamAssign::FindAllReduceParallel(const shared_ptr<session::KernelG
|
|||
}
|
||||
}
|
||||
|
||||
// section 4
|
||||
// Appends a StreamSwitch node to `orders` and, when the node carries a true
// kStreamNeedActivedFirst attribute, appends a freshly created StreamActive node after it.
//
// The StreamActive node is placed on the switch's true-branch stream
// (kAttrTrueBranchStream) and activates every stream in `independent_stream`; those
// streams are also recorded in processed_streams_ and independent_stream_activated_
// is set.
//
// Params:
//   graph_ptr           graph used to create the StreamActive node.
//   switch_ptr          the StreamSwitch node; must not be null.
//   independent_stream  stream ids the new StreamActive node should activate.
//   orders              in/out: execution order being rebuilt; must not be null.
// Raises an exception (MS_EXCEPTION_IF_NULL) if `orders`, `switch_ptr`, the switch's
// primitive, or the created StreamActive node is null.
void AscendStreamAssign::UpdateStreamSwitch(const std::shared_ptr<session::KernelGraph> &graph_ptr,
                                            const CNodePtr &switch_ptr, const vector<uint32_t> &independent_stream,
                                            vector<CNodePtr> *orders) {
  MS_EXCEPTION_IF_NULL(orders);
  // Validate the switch node before it is stored or dereferenced.
  MS_EXCEPTION_IF_NULL(switch_ptr);
  orders->emplace_back(switch_ptr);
  auto primitive = AnfAlgo::GetCNodePrimitive(switch_ptr);
  MS_EXCEPTION_IF_NULL(primitive);
  auto value_ptr = primitive->GetAttr(kStreamNeedActivedFirst);
  if (value_ptr == nullptr) {
    return;
  }

  auto need_active = GetValue<bool>(value_ptr);
  if (!need_active) {
    return;
  }

  MS_LOG(INFO) << "start update switch op[" << switch_ptr->DebugString() << "]";
  auto true_stream_id = GetValue<uint32_t>(primitive->GetAttr(kAttrTrueBranchStream));
  MS_LOG(INFO) << "streamswtich stream id[" << AnfAlgo::GetStreamId(switch_ptr) << "], true_logic_id[" << true_stream_id
               << "]";

  CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr);
  MS_EXCEPTION_IF_NULL(active_ptr);
  MS_LOG(INFO) << "start update StreamActive op[" << active_ptr->DebugString() << "]";
  AnfAlgo::SetStreamId(true_stream_id, active_ptr.get());
  AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(independent_stream), active_ptr);
  independent_stream_activated_ = true;

  // update processed stream
  for (const auto &item : independent_stream) {
    processed_streams_.emplace(item);
  }

  orders->emplace_back(active_ptr);
}
|
||||
|
||||
// Rebuilds the graph's execution order, inserting StreamActive nodes where needed.
//
// For every non-independent node:
//   1) if its stream differs from the previous node's stream, the previous node is
//      neither a StreamSwitch nor a Send, and the stream is not yet in
//      processed_streams_, a StreamActive node is inserted on the previous stream;
//   2) if independent streams exist and the node is a StreamSwitch, it is handed to
//      UpdateStreamSwitch (which appends it, possibly followed by a StreamActive);
//      otherwise the node is appended unchanged.
// Independent nodes are appended as they are and do not take part in the
// "previous node" tracking.
void AscendStreamAssign::InsertStreamActive(const std::shared_ptr<session::KernelGraph> &graph_ptr) {
  MS_LOG(INFO) << "start";
  MS_EXCEPTION_IF_NULL(graph_ptr);

  // Collect the ids of all independent streams known so far.
  MS_LOG(INFO) << "independent stream size:" << independent_stream_map_.size();
  std::vector<uint32_t> independent_stream;
  for (const auto &entry : independent_stream_map_) {
    independent_stream.emplace_back(entry.first);
  }
  const bool has_independent = !independent_stream.empty();

  std::vector<CNodePtr> new_order;
  CNodePtr prev_node = nullptr;
  uint32_t prev_stream = UINT32_MAX;
  const std::vector<CNodePtr> &exec_order = graph_ptr->execution_order();
  for (size_t idx = 0; idx < exec_order.size(); ++idx) {
    CNodePtr node = exec_order[idx];
    MS_EXCEPTION_IF_NULL(node);
    const uint32_t node_stream = AnfAlgo::GetStreamId(node);
    if (IsIndependentNode(node)) {
      new_order.emplace_back(node);
      continue;
    }

    // A stream change only counts when there is a previous node and that node is
    // neither a StreamSwitch nor a Send.
    bool stream_changed = false;
    if (prev_node != nullptr) {
      stream_changed = prev_stream != node_stream && AnfAlgo::GetCNodeName(prev_node) != kStreamSwitchOpName &&
                       AnfAlgo::GetCNodeName(prev_node) != kSendOpName;
    }

    // 1)inner stream assign, need insert active op
    if (stream_changed && !IsProcessedStream(node_stream)) {
      MS_LOG(INFO) << "Inner insert active op, self stream id[" << prev_stream << "]";
      CNodePtr active_node = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr);
      // The active op runs on the previous stream ...
      AnfAlgo::SetStreamId(prev_stream, active_node.get());
      // ... and activates the current stream together with its parallel streams.
      std::vector<uint32_t> streams_to_activate;
      GetParallelStream(node_stream, prev_stream, &streams_to_activate);
      AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(streams_to_activate), active_node);
      new_order.emplace_back(active_node);
    }

    if (has_independent && (AnfAlgo::GetCNodeName(node) == kStreamSwitchOpName)) {
      MS_LOG(INFO) << "Insert StreamActive op after FP StreamSwitch for stream parallel";
      UpdateStreamSwitch(graph_ptr, node, independent_stream, &new_order);
    } else {
      new_order.emplace_back(node);
    }

    processed_streams_.emplace(node_stream);
    prev_stream = node_stream;
    prev_node = node;
  }
  graph_ptr->set_execution_order(new_order);
  MS_LOG(INFO) << "end";
}
|
||||
|
||||
bool AscendStreamAssign::IsProcessedStream(uint32_t stream_id) {
|
||||
auto it = std::find(processed_streams_.begin(), processed_streams_.end(), stream_id);
|
||||
if (it != processed_streams_.end()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void AscendStreamAssign::GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id,
|
||||
vector<uint32_t> *parallel_streams) {
|
||||
MS_EXCEPTION_IF_NULL(parallel_streams);
|
||||
for (size_t i = 0; i < inner_parallel_streams_.size(); i++) {
|
||||
const auto &cur_parallel_streams = inner_parallel_streams_[i];
|
||||
auto it = std::find(cur_parallel_streams.begin(), cur_parallel_streams.end(), cur_stream_id);
|
||||
if (it != cur_parallel_streams.end()) {
|
||||
MS_LOG(INFO) << "stream id:" << cur_stream_id << " is parallel stream";
|
||||
for (size_t j = 0; j < cur_parallel_streams.size(); j++) {
|
||||
if (cur_parallel_streams[j] == stream_acitve_id) {
|
||||
MS_LOG(INFO) << "one of parallel stream id" << cur_parallel_streams[j]
|
||||
<< "is same with streamacvite stream id" << stream_acitve_id;
|
||||
continue;
|
||||
}
|
||||
(*parallel_streams).emplace_back(cur_parallel_streams[j]);
|
||||
processed_streams_.emplace(cur_parallel_streams[j]);
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
processed_streams_.emplace(cur_stream_id);
|
||||
(*parallel_streams).push_back(cur_stream_id);
|
||||
}
|
||||
|
||||
// section5
|
||||
void AscendStreamAssign::InsertSendRecvForDiffHcom(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr) {
|
||||
MS_LOG(INFO) << "start";
|
||||
MS_EXCEPTION_IF_NULL(graph_ptr);
|
||||
|
@ -299,7 +413,7 @@ void AscendStreamAssign::InsertSendRecvForDiffHcom(const shared_ptr<mindspore::s
|
|||
vector<CNodePtr> orders;
|
||||
for (size_t i = 0; i < cnode_ptr_list.size(); i++) {
|
||||
auto cur_cnode = cnode_ptr_list[i];
|
||||
if (IsHcom(cur_cnode)) {
|
||||
if (IsFusionHcom(cur_cnode)) {
|
||||
fusion_hcom_index.emplace_back(i);
|
||||
}
|
||||
}
|
||||
|
@ -310,7 +424,7 @@ void AscendStreamAssign::InsertSendRecvForDiffHcom(const shared_ptr<mindspore::s
|
|||
uint32_t first_index = fusion_hcom_index[0];
|
||||
uint32_t last_index = fusion_hcom_index[fusion_hcom_index.size() - 1];
|
||||
uint32_t cur_event_id = total_event_num_;
|
||||
uint32_t pre_hcom_stream_id = UINT32_MAX;
|
||||
uint32_t pre_hcom_stream_id = kInvalidStreamId;
|
||||
std::copy(cnode_ptr_list.begin(), cnode_ptr_list.begin() + first_index, std::back_inserter(orders));
|
||||
for (size_t i = first_index; i <= last_index; i++) {
|
||||
auto cur_cnode = cnode_ptr_list[i];
|
||||
|
@ -362,6 +476,11 @@ void AscendStreamAssign::InsertSendRecvForHcomParallel(const shared_ptr<mindspor
|
|||
MS_EXCEPTION_IF_NULL(*it);
|
||||
MS_EXCEPTION_IF_NULL(*(it + 1));
|
||||
if (IsHcom(*it) && !IsHcom(*(it + 1))) {
|
||||
bool is_fusion = IsFusionHcom(*it);
|
||||
if (!is_fusion) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it));
|
||||
it = cnodes.insert(it + 1, send_cnode_ptr);
|
||||
|
||||
|
@ -390,99 +509,6 @@ void AscendStreamAssign::InsertSendRecvForHcomParallel(const shared_ptr<mindspor
|
|||
MS_LOG(INFO) << "end";
|
||||
}
|
||||
|
||||
bool AscendStreamAssign::IsProcessedParallelStream(uint32_t stream_id) {
|
||||
auto it = std::find(processed_parallel_streams_.begin(), processed_parallel_streams_.end(), stream_id);
|
||||
if (it != processed_parallel_streams_.end()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void AscendStreamAssign::GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id,
|
||||
vector<uint32_t> *parallel_streams) {
|
||||
for (size_t i = 0; i < inner_parallel_streams_.size(); i++) {
|
||||
auto cur_parallel_streams = inner_parallel_streams_[i];
|
||||
auto it = std::find(cur_parallel_streams.begin(), cur_parallel_streams.end(), cur_stream_id);
|
||||
if (it != cur_parallel_streams.end()) {
|
||||
MS_LOG(INFO) << "stream id:" << cur_stream_id << " is parallel stream";
|
||||
for (size_t j = 0; j < cur_parallel_streams.size(); j++) {
|
||||
if (cur_parallel_streams[j] == stream_acitve_id) {
|
||||
MS_LOG(INFO) << "one of parallel stream id" << cur_parallel_streams[j]
|
||||
<< "is same with streamacvite stream id" << stream_acitve_id;
|
||||
continue;
|
||||
}
|
||||
(*parallel_streams).emplace_back(cur_parallel_streams[j]);
|
||||
}
|
||||
|
||||
// record processed parallel streams
|
||||
(void)std::copy((*parallel_streams).begin(), (*parallel_streams).end(),
|
||||
std::back_inserter(processed_parallel_streams_));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
(*parallel_streams).push_back(cur_stream_id);
|
||||
}
|
||||
|
||||
// Rebuilds the execution order of `graph_ptr`, inserting and updating stream-control ops.
//
// Walking the current execution order:
//   * nodes whose stream id is >= kIndependFirstStreamId are copied through unchanged and do not
//     affect the "previous node" tracking;
//   * when the stream id grows from the previous node to the current one, the previous node is
//     not a StreamSwitch/StreamActive/Send op and the current stream is not reported by
//     IsProcessedParallelStream, a new StreamActive op is inserted before the current node;
//   * an existing StreamActive op with a valid distinction label is passed to UpdateStreamActive;
//   * a StreamSwitch op gets a new StreamActive op right after it (see UpdateStreamSwitch).
// The resulting list replaces the graph's execution order.
void AscendStreamAssign::InsertActiveNew(const std::shared_ptr<session::KernelGraph> &graph_ptr) {
  MS_LOG(INFO) << "start";
  MS_EXCEPTION_IF_NULL(graph_ptr);

  std::vector<CNodePtr> new_orders;
  CNodePtr prev_node = nullptr;
  uint32_t prev_stream_id = UINT32_MAX;

  auto orders = graph_ptr->execution_order();
  for (const auto &node : orders) {
    MS_EXCEPTION_IF_NULL(node);
    const uint32_t stream_id = AnfAlgo::GetStreamId(node);
    if (stream_id >= kIndependFirstStreamId) {
      new_orders.emplace_back(node);
      continue;
    }

    // prev_stream_id starts at UINT32_MAX, so prev_node is only inspected once it has been set.
    bool inner_active = false;
    if (prev_stream_id < stream_id) {
      const auto prev_name = AnfAlgo::GetCNodeName(prev_node);
      inner_active =
        !(prev_name == kStreamSwitchOpName || prev_name == kStreamActiveOpName || prev_name == kSendOpName);
    }
    const bool already_processed = IsProcessedParallelStream(stream_id);

    // 1) inner stream assign: an active op has to be inserted
    if (inner_active && !already_processed) {
      MS_LOG(INFO) << "Inner insert active op, self stream id[" << prev_stream_id << "]";
      CNodePtr inserted_active = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr);
      new_orders.emplace_back(inserted_active);
      // the new op runs on the previous stream ...
      AnfAlgo::SetStreamId(prev_stream_id, inserted_active.get());
      // ... and activates the current stream together with its parallel streams
      std::vector<uint32_t> streams_to_activate;
      GetParallelStream(stream_id, prev_stream_id, &streams_to_activate);
      AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(streams_to_activate),
                           inserted_active);
    }

    // The inner-active case above is independent of the handling below,
    // e.g. StreamActive(S7)-->StreamActive(S8). The current node is kept in every case.
    new_orders.emplace_back(node);
    const auto node_name = AnfAlgo::GetCNodeName(node);
    if (node_name == kStreamActiveOpName && AnfAlgo::GetStreamDistinctionLabel(node.get()) != UINT32_MAX) {
      // 2) outer stream assign: update the existing active op
      UpdateStreamActive(node);
    } else if (node_name == kStreamSwitchOpName) {
      // 3) update the switch op and place a new active op right behind it
      MS_LOG(INFO) << "Insert active op after switch";
      CNodePtr switch_active = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr);
      new_orders.emplace_back(switch_active);
      UpdateStreamSwitch(node, switch_active);
    }

    prev_stream_id = stream_id;
    prev_node = node;
  }

  graph_ptr->set_execution_order(new_orders);
  MS_LOG(INFO) << "end";
}
|
||||
|
||||
void AscendStreamAssign::UpdateEventId(const shared_ptr<session::KernelGraph> &graph_ptr) {
|
||||
MS_LOG(INFO) << "start";
|
||||
MS_EXCEPTION_IF_NULL(graph_ptr);
|
||||
|
@ -514,64 +540,11 @@ void AscendStreamAssign::UpdateEventId(const shared_ptr<session::KernelGraph> &g
|
|||
}
|
||||
}
|
||||
|
||||
void AscendStreamAssign::UpdateStreamId(const shared_ptr<session::KernelGraph> &graph_ptr) {
|
||||
MS_LOG(INFO) << "start";
|
||||
MS_EXCEPTION_IF_NULL(graph_ptr);
|
||||
CNodePtr cur_cnode_ptr = nullptr;
|
||||
auto cnode_ptr_list = graph_ptr->execution_order();
|
||||
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
|
||||
cur_cnode_ptr = cnode_ptr_list[i];
|
||||
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
||||
uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
|
||||
if (cur_stream_id < kIndependFirstStreamId) {
|
||||
if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamActiveOpName) {
|
||||
auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
vector<uint32_t> active_ids = GetValue<std::vector<uint32_t>>(primitive->GetAttr(kAttrActiveStreamList));
|
||||
for (size_t j = 0; j < active_ids.size(); j++) {
|
||||
if (active_ids[j] >= kIndependFirstStreamId) {
|
||||
active_ids[j] = active_ids[j] - kIndependFirstStreamId + total_common_stream_num_;
|
||||
}
|
||||
}
|
||||
ValuePtr active_value = MakeValue<std::vector<uint32_t>>(active_ids);
|
||||
AnfAlgo::SetNodeAttr(kAttrActiveStreamList, active_value, cur_cnode_ptr);
|
||||
}
|
||||
} else {
|
||||
uint32_t update_id = cur_stream_id - kIndependFirstStreamId + total_common_stream_num_;
|
||||
AnfAlgo::SetStreamId(update_id, cur_cnode_ptr.get());
|
||||
}
|
||||
|
||||
// update AtomicAddrClean stream same witch the next node
|
||||
if (i > 0 && AnfAlgo::GetCNodeName(cnode_ptr_list[i - 1]) == "AtomicAddrClean") {
|
||||
MS_LOG(INFO) << "update AtomicAddrClean stream id from[" << AnfAlgo::GetStreamId(cnode_ptr_list[i - 1])
|
||||
<< "] to [" << AnfAlgo::GetStreamId(cur_cnode_ptr) << "]";
|
||||
AnfAlgo::SetStreamId(AnfAlgo::GetStreamId(cur_cnode_ptr), cnode_ptr_list[i - 1].get());
|
||||
}
|
||||
}
|
||||
|
||||
// update logic_to_independent_map_
|
||||
for (auto &indep : logic_to_independent_map_) {
|
||||
if (indep.second >= kIndependFirstStreamId) {
|
||||
indep.second = indep.second - kIndependFirstStreamId + total_common_stream_num_;
|
||||
}
|
||||
}
|
||||
|
||||
// update independent_before_physic_id_
|
||||
for (auto &id : independent_before_physic_id_) {
|
||||
if (id >= kIndependFirstStreamId) {
|
||||
id = id - kIndependFirstStreamId + total_common_stream_num_;
|
||||
}
|
||||
}
|
||||
|
||||
// update independent_id_
|
||||
independent_id_ = independent_id_ - kIndependFirstStreamId + total_common_stream_num_;
|
||||
MS_LOG(INFO) << "end";
|
||||
}
|
||||
|
||||
void AscendStreamAssign::GetNeedActiveStreams(const shared_ptr<session::KernelGraph> &graph_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(graph_ptr);
|
||||
CNodePtr cur_cnode_ptr = nullptr;
|
||||
auto cnode_ptr_list = graph_ptr->execution_order();
|
||||
// 1) streams with the kStreamNeedActivedFirst attr should be activated;
|
||||
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
|
||||
cur_cnode_ptr = cnode_ptr_list[i];
|
||||
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
||||
|
@ -589,29 +562,15 @@ void AscendStreamAssign::GetNeedActiveStreams(const shared_ptr<session::KernelGr
|
|||
need_first_active_streams_.push_back(stream_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void AscendStreamAssign::AssignStreamNew(const shared_ptr<session::KernelGraph> &graph_ptr) {
|
||||
if (IsTaskSink()) {
|
||||
ResetNew();
|
||||
ReorderIndependentOrders(graph_ptr);
|
||||
AssignAllNodesStream(graph_ptr);
|
||||
FindAllReduceParallel(graph_ptr);
|
||||
InsertActiveNew(graph_ptr);
|
||||
InsertSendRecvForHcomParallel(graph_ptr);
|
||||
InsertSendRecvForIndependent(graph_ptr);
|
||||
UpdateStreamId(graph_ptr);
|
||||
UpdateEventId(graph_ptr);
|
||||
GetNeedActiveStreams(graph_ptr);
|
||||
// 2) the first stream (stream 0) should be activated first;
|
||||
need_first_active_streams_.emplace_back(0);
|
||||
|
||||
MS_LOG(INFO) << "after finish stream assign";
|
||||
graph_ptr->PrintGraphExecuteOrder();
|
||||
|
||||
// Get info for D Model
|
||||
generator::IRModelUtil::GetInstance().set_event_num(total_event_num());
|
||||
generator::IRModelUtil::GetInstance().set_stream_num(total_common_stream_num() + total_independ_stream_num());
|
||||
// Init to 1,temporarily
|
||||
generator::IRModelUtil::GetInstance().set_batch_num(1);
|
||||
// 3) independent streams: if they have not been activated yet, push them to the need-active vector
|
||||
if (!independent_stream_activated_) {
|
||||
for (auto &item : independent_stream_map_) {
|
||||
need_first_active_streams_.emplace_back(item.first);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -722,33 +681,6 @@ void AscendStreamAssign::InsertSendRecvForIndependent(const shared_ptr<session::
|
|||
MS_LOG(INFO) << "end";
|
||||
}
|
||||
|
||||
bool AscendStreamAssign::IsIndependentNode(const CNodePtr &node_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(node_ptr);
|
||||
if (AnfAlgo::GetKernelType(node_ptr) != AICPU_KERNEL) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (AnfAlgo::GetCNodeName(node_ptr) == kGetNextOpName) {
|
||||
MS_LOG(INFO) << "GetNext should not be independent node";
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t input_nums = AnfAlgo::GetInputTensorNum(node_ptr);
|
||||
if (input_nums == 0) {
|
||||
MS_LOG(INFO) << "node " << node_ptr->fullname_with_scope() << " is independent, as inputs nums is zero";
|
||||
return true;
|
||||
}
|
||||
|
||||
auto inputs = node_ptr->inputs();
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
if (!inputs[i]->isa<ValueNode>()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "node " << node_ptr->fullname_with_scope() << " is independent, as inputs is all value node";
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AscendStreamAssign::IsTaskSink() {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
|
@ -762,56 +694,54 @@ bool AscendStreamAssign::IsTaskSink() {
|
|||
}
|
||||
|
||||
void AscendStreamAssign::GetWaitStreams(vector<uint32_t> *wait_active_stream_list) {
|
||||
if (total_common_stream_num_ == 0) {
|
||||
MS_EXCEPTION_IF_NULL(wait_active_stream_list);
|
||||
AscendStreamMng &stream_manager = AscendStreamMng::GetInstance();
|
||||
uint32_t total_stream_num = stream_manager.GetCurAllocStreamNum();
|
||||
if (total_stream_num == 0) {
|
||||
MS_LOG(INFO) << "total_common_stream_num is zero";
|
||||
return;
|
||||
}
|
||||
|
||||
// common stream:active first common stream
|
||||
MS_LOG(INFO) << "active physic id[" << first_physic_id_ << "]";
|
||||
for (uint32_t i = first_physic_id_ + 1; i < total_common_stream_num_; i++) {
|
||||
for (uint32_t i = 0; i < total_stream_num; i++) {
|
||||
auto it = std::find(need_first_active_streams_.begin(), need_first_active_streams_.end(), i);
|
||||
if (it == need_first_active_streams_.end()) {
|
||||
MS_LOG(INFO) << "wait common stream id = " << i;
|
||||
(*wait_active_stream_list).push_back(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// all independ stream id before first physical stream id should be actived
|
||||
auto it = logic_to_independent_map_.find(first_logic_id_);
|
||||
if (it != logic_to_independent_map_.end()) {
|
||||
uint32_t independent_id = it->second;
|
||||
auto res = std::find(independent_before_physic_id_.begin(), independent_before_physic_id_.end(), independent_id);
|
||||
if (res == independent_before_physic_id_.end()) {
|
||||
// first physical to independ id may be not in independent_before_physic_id_
|
||||
independent_before_physic_id_.push_back(independent_id);
|
||||
}
|
||||
MS_LOG(INFO) << "active independent id[" << independent_id << "]";
|
||||
bool AscendStreamAssign::IsHcom(const CNodePtr &apply_kernel) {
|
||||
MS_EXCEPTION_IF_NULL(apply_kernel);
|
||||
return AnfAlgo::GetKernelType(apply_kernel) == HCCL_KERNEL;
|
||||
}
|
||||
|
||||
bool AscendStreamAssign::IsFusionHcom(const CNodePtr &cur_cnode_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
||||
bool is_hcom = IsHcom(cur_cnode_ptr);
|
||||
if (!is_hcom) {
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t max_before_physic = 0;
|
||||
for (size_t i = 0; i < independent_before_physic_id_.size(); i++) {
|
||||
if (independent_before_physic_id_[i] > max_before_physic) {
|
||||
max_before_physic = independent_before_physic_id_[i];
|
||||
}
|
||||
MS_LOG(INFO) << "independent id[" << independent_before_physic_id_[i] << "] before first physic is active";
|
||||
if (!AnfAlgo::HasNodeAttr(kAttrFusion, cur_cnode_ptr)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < total_independ_stream_num_; i++) {
|
||||
if (i + total_common_stream_num_ <= max_before_physic) {
|
||||
continue;
|
||||
}
|
||||
// all wait streams should not in need_first_active_streams_
|
||||
auto iter =
|
||||
std::find(need_first_active_streams_.begin(), need_first_active_streams_.end(), i + total_common_stream_num_);
|
||||
if (iter == need_first_active_streams_.end()) {
|
||||
MS_LOG(INFO) << "wait independent stream id:" << i + total_common_stream_num_;
|
||||
(*wait_active_stream_list).push_back(i + total_common_stream_num_);
|
||||
}
|
||||
if (AnfAlgo::GetNodeAttr<int>(cur_cnode_ptr, kAttrFusion) == 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void AscendStreamAssign::GetHcomStreams(std::vector<uint32_t> *streams) {
|
||||
MS_EXCEPTION_IF_NULL(streams);
|
||||
for (const auto &stream : hcom_stream_list_) {
|
||||
(*streams).emplace_back(stream);
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t AscendStreamAssign::GetTotalStreamNum() const { return total_common_stream_num_ + total_independ_stream_num_; }
|
||||
void AscendStreamAssign::ReorderIndependentOrders(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(graph_ptr);
|
||||
CNodePtr cur_cnode_ptr = nullptr;
|
||||
|
@ -829,24 +759,19 @@ void AscendStreamAssign::ReorderIndependentOrders(const shared_ptr<mindspore::se
|
|||
others.emplace_back(cur_cnode_ptr);
|
||||
}
|
||||
}
|
||||
if (others.empty()) {
|
||||
std::copy(independents.begin(), independents.end(), std::back_inserter(exe_orders));
|
||||
graph_ptr->set_execution_order(exe_orders);
|
||||
if (others.empty() || independents.empty()) {
|
||||
MS_LOG(INFO) << "independent or others is empty, no need reorder";
|
||||
return;
|
||||
}
|
||||
if (independents.empty()) {
|
||||
std::copy(others.begin(), others.end(), std::back_inserter(exe_orders));
|
||||
graph_ptr->set_execution_order(exe_orders);
|
||||
return;
|
||||
}
|
||||
std::vector<CNodePtr> processed;
|
||||
|
||||
std::set<CNode *> processed;
|
||||
for (size_t i = 0; i < others.size(); i++) {
|
||||
auto begin = others.begin() + i;
|
||||
auto end = begin + 1;
|
||||
bool flag = false;
|
||||
for (size_t j = 0; j < independents.size(); j++) {
|
||||
auto cur_independent = independents[j];
|
||||
auto it = std::find(processed.begin(), processed.end(), cur_independent);
|
||||
auto it = std::find(processed.begin(), processed.end(), cur_independent.get());
|
||||
if (it != processed.end()) {
|
||||
continue;
|
||||
}
|
||||
|
@ -855,7 +780,7 @@ void AscendStreamAssign::ReorderIndependentOrders(const shared_ptr<mindspore::se
|
|||
flag = true;
|
||||
exe_orders.emplace_back(cur_independent);
|
||||
exe_orders.emplace_back(*begin);
|
||||
processed.emplace_back(cur_independent);
|
||||
processed.emplace(cur_independent.get());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -864,8 +789,23 @@ void AscendStreamAssign::ReorderIndependentOrders(const shared_ptr<mindspore::se
|
|||
}
|
||||
}
|
||||
MS_LOG(INFO) << "after reorder, graph orders size:" << exe_orders.size();
|
||||
if (processed.size() != independents.size()) {
|
||||
MS_LOG(WARNING) << "processed independent nodes size is not equal to exiting independent nodes size";
|
||||
return;
|
||||
}
|
||||
|
||||
graph_ptr->set_execution_order(exe_orders);
|
||||
}
|
||||
|
||||
// Restores the bookkeeping members touched below to their empty / initial state.
void AscendStreamAssign::Reset() {
  // scalar state
  independent_stream_activated_ = false;
  total_event_num_ = 0;

  // stream collections
  inner_parallel_streams_.clear();
  need_first_active_streams_.clear();
  hcom_stream_list_.clear();
  processed_streams_.clear();
  independent_stream_map_.clear();
}
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,6 +19,8 @@
|
|||
|
||||
#include <functional>
|
||||
#include <unordered_map>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
@ -36,6 +38,36 @@ using std::shared_ptr;
|
|||
using std::unordered_map;
|
||||
using std::unordered_set;
|
||||
using std::vector;
|
||||
using CnodeKey = void *;
|
||||
const uint32_t kInvalidStreamId = UINT32_MAX;
|
||||
// Hands out consecutive stream ids, starting at 0, and counts how many have been handed out.
// GetInstance() returns a function-local static instance that callers share.
class AscendStreamMng {
 public:
  // Returns the shared instance.
  static AscendStreamMng &GetInstance() {
    static AscendStreamMng instance;
    return instance;
  }

  // Forgets all allocations: the next ApplyNewStream() returns 0 again.
  void Reset() {
    cur_stream_id = 0;
    cur_stream_num = 0;
  }

  // Allocates and returns the next stream id. The first allocation after construction or Reset()
  // yields 0; every later one yields the previous id plus one.
  uint32_t ApplyNewStream() {
    if (cur_stream_num != 0) {
      ++cur_stream_id;
    }
    ++cur_stream_num;
    return cur_stream_id;
  }

  // Most recently allocated stream id (0 when nothing has been allocated yet).
  uint32_t GetCurAllocStream() const { return cur_stream_id; }
  // Number of stream ids allocated since construction or the last Reset().
  uint32_t GetCurAllocStreamNum() const { return cur_stream_num; }

 private:
  uint32_t cur_stream_num{0};
  uint32_t cur_stream_id{0};
};
|
||||
|
||||
class AscendStreamAssign {
|
||||
public:
|
||||
|
@ -47,22 +79,11 @@ class AscendStreamAssign {
|
|||
AscendStreamAssign(const AscendStreamAssign &) = delete;
|
||||
AscendStreamAssign &operator=(const AscendStreamAssign &) = delete;
|
||||
|
||||
uint32_t GetTotalStreamNum() const;
|
||||
// new stream policy
|
||||
uint32_t total_common_stream_num() const { return total_common_stream_num_; }
|
||||
uint32_t total_independ_stream_num() const { return total_independ_stream_num_; }
|
||||
uint32_t total_event_num() const { return total_event_num_; }
|
||||
void GetHcomStreams(std::vector<uint32_t> *streams);
|
||||
|
||||
void InsertActiveNew(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||
void AssignAllNodesStream(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||
void ResetNew();
|
||||
void AssignStreamNew(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||
bool IsIndependentNode(const CNodePtr &node_ptr);
|
||||
const std::unordered_map<uint32_t, uint32_t> &logic_to_independent_map() { return logic_to_independent_map_; }
|
||||
const std::unordered_map<uint32_t, uint32_t> &logic_to_physic_map() { return logic_to_physic_map_; }
|
||||
const std::vector<std::vector<uint32_t>> &inner_parallel_streams() { return inner_parallel_streams_; }
|
||||
void AssignStream(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||
void GetWaitStreams(vector<uint32_t> *wait_active_stream_list);
|
||||
const std::vector<uint32_t> &hcom_streams() { return hcom_stream_list_; }
|
||||
CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id,
|
||||
uint32_t stream_id);
|
||||
CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id,
|
||||
|
@ -71,49 +92,41 @@ class AscendStreamAssign {
|
|||
private:
|
||||
AscendStreamAssign() = default;
|
||||
~AscendStreamAssign() = default;
|
||||
|
||||
vector<CNodePtr>::iterator FindTargetOp(vector<CNodePtr>::iterator begin, vector<CNodePtr>::iterator end,
|
||||
const CNodePtr &node);
|
||||
|
||||
bool IsHcom(const CNodePtr &apply_kernel);
|
||||
bool IsProcessed(uint32_t logic_id);
|
||||
void TransLogicToPhysic(const vector<uint32_t> &logic_ids, vector<uint32_t> *physic_ids);
|
||||
void Reset();
|
||||
void CheckStreamAssign(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||
void AssignAllNodesStream(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||
void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr, CNodePtr *pre_cnode_ptr, uint32_t *cur_index,
|
||||
uint32_t *cur_stream_id);
|
||||
void RecordIdMap(uint32_t logic_id, uint32_t physic_id);
|
||||
void UpdateStreamActive(const CNodePtr &active_ptr);
|
||||
void UpdateStreamSwitch(const CNodePtr &switch_ptr, const CNodePtr &active_ptr);
|
||||
bool IsTaskSink();
|
||||
void AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, uint32_t deal_logic_id);
|
||||
void UpdateStreamId(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||
void UpdateEventId(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||
void RecordFirstCommonOp(const CNodePtr &cur_cnode_ptr, uint32_t cur_node_logic_id, uint32_t cur_stream_id);
|
||||
uint32_t GetLogicId(const CNodePtr &cur_cnode_ptr);
|
||||
void SetCommonStreamNum(uint32_t cur_stream_id);
|
||||
void FindAllReduceParallel(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||
bool IsProcessedParallelStream(uint32_t stream_id);
|
||||
void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t> *parallel_streams);
|
||||
void AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr);
|
||||
void UpdateAtomicAddrCleanStreamId(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||
void FindHcomParallelStreams(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||
void InsertStreamActive(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||
void UpdateStreamSwitch(const std::shared_ptr<session::KernelGraph> &graph_ptr, const CNodePtr &switch_ptr,
|
||||
const vector<uint32_t> &independent_stream, vector<CNodePtr> *orders);
|
||||
void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||
void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||
void InsertSendRecvForDiffHcom(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr);
|
||||
void UpdateEventId(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||
void GetNeedActiveStreams(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||
void ReorderIndependentOrders(const std::shared_ptr<session::KernelGraph> &graph_ptr);
|
||||
|
||||
uint32_t total_common_stream_num_{0};
|
||||
uint32_t total_independ_stream_num_{0};
|
||||
uint32_t total_event_num_{0};
|
||||
bool IsTaskSink();
|
||||
bool IsFusionHcom(const CNodePtr &cur_cnode_ptr);
|
||||
bool IsHcom(const CNodePtr &cur_cnode_ptr);
|
||||
bool IsIndependentNode(const CNodePtr &node_ptr);
|
||||
bool IsProcessedStream(uint32_t stream_id);
|
||||
vector<CNodePtr>::iterator FindTargetOp(vector<CNodePtr>::iterator begin, vector<CNodePtr>::iterator end,
|
||||
const CNodePtr &node);
|
||||
void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t> *parallel_streams);
|
||||
|
||||
uint32_t first_physic_id_{UINT32_MAX};
|
||||
uint32_t first_logic_id_{UINT32_MAX};
|
||||
uint32_t independent_id_{UINT32_MAX};
|
||||
vector<uint32_t> processed_logic_id_{};
|
||||
std::unordered_map<uint32_t, uint32_t> logic_to_physic_map_{}; // key:logic id, value: first physic id
|
||||
std::unordered_map<uint32_t, uint32_t> logic_to_independent_map_{}; // key:logic id, value: dependent id
|
||||
std::vector<uint32_t> independent_before_physic_id_{}; // record independent id before first physic id
|
||||
std::vector<std::vector<uint32_t>> inner_parallel_streams_{};
|
||||
std::vector<uint32_t> processed_parallel_streams_{};
|
||||
std::vector<uint32_t> hcom_stream_list_{};
|
||||
uint32_t total_event_num_{0};
|
||||
bool independent_stream_activated_{false};
|
||||
std::map<uint32_t, uint32_t> independent_stream_map_{};
|
||||
std::set<uint32_t> processed_streams_{};
|
||||
std::set<uint32_t> hcom_stream_list_{};
|
||||
std::vector<uint32_t> need_first_active_streams_{};
|
||||
std::vector<std::vector<uint32_t>> inner_parallel_streams_{};
|
||||
|
||||
// new policy end
|
||||
};
|
||||
} // namespace ascend
|
||||
|
|
|
@ -37,24 +37,6 @@
|
|||
namespace mindspore {
|
||||
namespace device {
|
||||
using device::ascend::ProfilingUtils;
|
||||
// Moves every node whose op name is in kOptOperatorSet to the end of the execution order.
// The relative order inside both groups (other nodes first, then the kOptOperatorSet nodes) is
// the order in which the nodes appeared originally.
void KernelAdjust::Reorder(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  const std::vector<CNodePtr> &current_order = kernel_graph->execution_order();

  std::vector<CNodePtr> reordered;     // starts with the nodes that keep their place
  std::vector<CNodePtr> deferred_ops;  // kOptOperatorSet nodes, appended at the end
  for (const auto &node : current_order) {
    const bool is_opt_op = kOptOperatorSet.find(AnfAlgo::GetCNodeName(node)) != kOptOperatorSet.end();
    if (is_opt_op) {
      deferred_ops.emplace_back(node);
    } else {
      reordered.emplace_back(node);
    }
  }

  reordered.insert(reordered.end(), deferred_ops.begin(), deferred_ops.end());
  kernel_graph->set_execution_order(reordered);
}
|
||||
|
||||
void KernelAdjust::ReorderGetNext(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
|
||||
const std::vector<CNodePtr> &origin_cnode_list = kernel_graph_ptr->execution_order();
|
||||
|
@ -80,23 +62,6 @@ bool KernelAdjust::NeedInsertSwitch() {
|
|||
ConfigManager::GetInstance().iter_num() > 1);
|
||||
}
|
||||
|
||||
// Returns the stream distinction label of the first StreamSwitch op in the execution order,
// or kInvalidDistincLabel when the graph contains no StreamSwitch op.
uint32_t KernelAdjust::FindFirstStreamSwitchLabel(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
  MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  auto orders = kernel_graph_ptr->execution_order();
  for (const auto &node : orders) {
    MS_EXCEPTION_IF_NULL(node);
    if (AnfAlgo::GetCNodeName(node) == kStreamSwitchOpName) {
      return AnfAlgo::GetStreamDistinctionLabel(node.get());
    }
  }
  return kInvalidDistincLabel;
}
|
||||
|
||||
CNodePtr KernelAdjust::CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr,
|
||||
uint32_t event_id) {
|
||||
MS_EXCEPTION_IF_NULL(graph_ptr);
|
||||
|
@ -138,6 +103,8 @@ CNodePtr KernelAdjust::CreateRecvApplyKernel(const std::shared_ptr<session::Kern
|
|||
}
|
||||
|
||||
void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
|
||||
device::ascend::AscendStreamMng &stream_manager = device::ascend::AscendStreamMng::GetInstance();
|
||||
stream_manager.Reset();
|
||||
if (!NeedInsertSwitch()) {
|
||||
return;
|
||||
}
|
||||
|
@ -166,68 +133,62 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph>
|
|||
if (orders.empty()) {
|
||||
MS_LOG(EXCEPTION) << "graph execution order is empty";
|
||||
}
|
||||
uint32_t first_cnode_stream_label = AnfAlgo::GetStreamDistinctionLabel(orders[0].get());
|
||||
|
||||
std::vector<CNodePtr> exec_order;
|
||||
CNodePtr first_stream_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input);
|
||||
MS_EXCEPTION_IF_NULL(first_stream_switch_app);
|
||||
AnfAlgo::SetStreamDistinctionLabel(kFirstStreamSwitchLabel, first_stream_switch_app.get());
|
||||
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(kGetNextLabel), first_stream_switch_app);
|
||||
|
||||
CNodePtr second_stream_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input);
|
||||
MS_EXCEPTION_IF_NULL(second_stream_switch_app);
|
||||
AnfAlgo::SetStreamDistinctionLabel(kSecondStreamSwitchLabel, second_stream_switch_app.get());
|
||||
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(first_cnode_stream_label), second_stream_switch_app);
|
||||
// add attr "stream_need_active"
|
||||
AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), second_stream_switch_app);
|
||||
// getnext loop process
|
||||
// getnext loop stream switch op
|
||||
CNodePtr getnext_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input);
|
||||
MS_EXCEPTION_IF_NULL(getnext_switch_app);
|
||||
uint32_t getnext_switch_stream_id = stream_manager.ApplyNewStream();
|
||||
AnfAlgo::SetStreamId(getnext_switch_stream_id, getnext_switch_app.get());
|
||||
exec_order.push_back(getnext_switch_app);
|
||||
|
||||
CNodePtr first_stream_active_app = CreateStreamActiveOp(kernel_graph_ptr);
|
||||
MS_EXCEPTION_IF_NULL(first_stream_active_app);
|
||||
AnfAlgo::SetStreamDistinctionLabel(first_cnode_stream_label, first_stream_active_app.get());
|
||||
std::vector<uint32_t> first_active_streams = {kFirstStreamSwitchLabel};
|
||||
AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(first_active_streams),
|
||||
first_stream_active_app);
|
||||
|
||||
CNodePtr second_stream_active_app = CreateStreamActiveOp(kernel_graph_ptr);
|
||||
MS_EXCEPTION_IF_NULL(second_stream_active_app);
|
||||
// specific deal for common ctrl stream policy
|
||||
uint32_t first_common_stream_switch_label = FindFirstStreamSwitchLabel(kernel_graph_ptr);
|
||||
if (first_common_stream_switch_label == kInvalidDistincLabel) {
|
||||
AnfAlgo::SetStreamDistinctionLabel(first_cnode_stream_label, second_stream_active_app.get());
|
||||
} else {
|
||||
AnfAlgo::SetStreamDistinctionLabel(first_common_stream_switch_label, second_stream_active_app.get());
|
||||
}
|
||||
|
||||
std::vector<uint32_t> second_active_streams = {kSecondStreamSwitchLabel};
|
||||
AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(second_active_streams),
|
||||
second_stream_active_app);
|
||||
|
||||
CNodePtr assign_add_one = CreateStreamAssignAddnOP(kernel_graph_ptr, switch_loop_input);
|
||||
MS_EXCEPTION_IF_NULL(assign_add_one);
|
||||
AnfAlgo::SetStreamDistinctionLabel(first_cnode_stream_label, assign_add_one.get());
|
||||
|
||||
CNodePtr send = CreateSendApplyKernel(kernel_graph_ptr, kFirstEventId);
|
||||
AnfAlgo::SetStreamDistinctionLabel(kGetNextLabel, send.get());
|
||||
CNodePtr recv = CreateRecvApplyKernel(kernel_graph_ptr, kFirstEventId);
|
||||
AnfAlgo::SetStreamDistinctionLabel(first_cnode_stream_label, recv.get());
|
||||
|
||||
// reorder graph orders
|
||||
exec_order.push_back(first_stream_switch_app);
|
||||
// getnext op
|
||||
uint32_t getnext_stream_id = stream_manager.ApplyNewStream();
|
||||
size_t i = 0;
|
||||
for (; i < orders.size(); i++) {
|
||||
auto node = orders[i];
|
||||
exec_order.push_back(node);
|
||||
AnfAlgo::SetStreamDistinctionLabel(kGetNextLabel, exec_order[exec_order.size() - 1].get());
|
||||
AnfAlgo::SetStreamId(getnext_stream_id, exec_order[exec_order.size() - 1].get());
|
||||
if (AnfAlgo::GetCNodeName(node) == kGetNextOpName) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// update getnext loop stream switch true_branch_stream attr
|
||||
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(getnext_stream_id), getnext_switch_app);
|
||||
|
||||
// getnext loop send
|
||||
CNodePtr send = CreateSendApplyKernel(kernel_graph_ptr, kFirstEventId);
|
||||
AnfAlgo::SetStreamId(getnext_stream_id, send.get());
|
||||
exec_order.push_back(send);
|
||||
exec_order.push_back(second_stream_switch_app);
|
||||
|
||||
// fpbp loop process
|
||||
// fpbp loop stream switch
|
||||
CNodePtr fpbp_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input);
|
||||
MS_EXCEPTION_IF_NULL(fpbp_switch_app);
|
||||
uint32_t fpbp_switch_stream_id = stream_manager.ApplyNewStream();
|
||||
AnfAlgo::SetStreamId(fpbp_switch_stream_id, fpbp_switch_app.get());
|
||||
AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), fpbp_switch_app);
|
||||
exec_order.push_back(fpbp_switch_app);
|
||||
|
||||
// fpbp loop recv
|
||||
CNodePtr recv = CreateRecvApplyKernel(kernel_graph_ptr, kFirstEventId);
|
||||
uint32_t fpbp_stream_id = stream_manager.ApplyNewStream();
|
||||
AnfAlgo::SetStreamId(fpbp_stream_id, recv.get());
|
||||
exec_order.push_back(recv);
|
||||
|
||||
// update fpbp loop stream switch true_branch_stream attr
|
||||
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(fpbp_stream_id), fpbp_switch_app);
|
||||
|
||||
// fpbp loop AssignAdd
|
||||
CNodePtr assign_add_one = CreateStreamAssignAddnOP(kernel_graph_ptr, switch_loop_input);
|
||||
MS_EXCEPTION_IF_NULL(assign_add_one);
|
||||
AnfAlgo::SetStreamId(fpbp_stream_id, assign_add_one.get());
|
||||
exec_order.push_back(assign_add_one);
|
||||
|
||||
// fpbp memcpy
|
||||
std::vector<CNodePtr> memcpy_list;
|
||||
std::vector<CNodePtr> before_list;
|
||||
std::vector<CNodePtr> after_list;
|
||||
|
@ -244,12 +205,28 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph>
|
|||
before_list.emplace_back(cur_cnode);
|
||||
}
|
||||
}
|
||||
|
||||
(void)std::copy(before_list.begin(), before_list.end(), std::back_inserter(exec_order));
|
||||
(void)std::copy(memcpy_list.begin(), memcpy_list.end(), std::back_inserter(exec_order));
|
||||
exec_order.push_back(first_stream_active_app);
|
||||
|
||||
// stream active to activate getnext loop
|
||||
CNodePtr getnext_active_app = CreateStreamActiveOp(kernel_graph_ptr);
|
||||
MS_EXCEPTION_IF_NULL(getnext_active_app);
|
||||
std::vector<uint32_t> getnext_active_streams = {getnext_switch_stream_id};
|
||||
AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(getnext_active_streams),
|
||||
getnext_active_app);
|
||||
exec_order.push_back(getnext_active_app);
|
||||
|
||||
// fpbp loop other ops
|
||||
(void)std::copy(after_list.begin(), after_list.end(), std::back_inserter(exec_order));
|
||||
exec_order.push_back(second_stream_active_app);
|
||||
|
||||
// stream active to activate fpbp loop
|
||||
CNodePtr fpbp_active_app = CreateStreamActiveOp(kernel_graph_ptr);
|
||||
MS_EXCEPTION_IF_NULL(fpbp_active_app);
|
||||
// specific deal for common ctrl stream policy
|
||||
std::vector<uint32_t> fpbp_active_streams = {fpbp_switch_stream_id};
|
||||
AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(fpbp_active_streams), fpbp_active_app);
|
||||
exec_order.push_back(fpbp_active_app);
|
||||
|
||||
kernel_graph_ptr->set_execution_order(exec_order);
|
||||
}
|
||||
|
||||
|
|
|
@ -39,9 +39,9 @@ constexpr auto kZeroParamName = "zero";
|
|||
constexpr auto kOneParamName = "one";
|
||||
constexpr auto kStreamNeedActivedFirst = "stream_need_active_first";
|
||||
|
||||
const uint32_t kFirstStreamSwitchLabel = kInvalidDistincLabel - 1;
|
||||
const uint32_t kGetNextLabel = kInvalidDistincLabel - 2;
|
||||
const uint32_t kSecondStreamSwitchLabel = kInvalidDistincLabel - 3;
|
||||
const uint32_t kFirstStreamSwitchLabel = 0;
|
||||
const uint32_t kGetNextLabel = 1;
|
||||
const uint32_t kSecondStreamSwitchLabel = 2;
|
||||
// Event id named as the "invalid" marker; defined as the largest uint32_t value.
const uint32_t kInvalidEventId = UINT32_MAX;
|
||||
// Half of kInvalidEventId (integer division).
const uint32_t kFirstEventId = kInvalidEventId / 2;
|
||||
namespace device {
|
||||
|
@ -51,7 +51,7 @@ class KernelAdjust {
|
|||
static KernelAdjust instance;
|
||||
return instance;
|
||||
}
|
||||
void Reorder(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
|
||||
void InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
|
||||
bool StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
|
||||
void Profiling(NotNull<session::KernelGraph *> kernel_graph_ptr);
|
||||
|
@ -65,7 +65,6 @@ class KernelAdjust {
|
|||
void ReorderGetNext(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
|
||||
CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id);
|
||||
CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id);
|
||||
uint32_t FindFirstStreamSwitchLabel(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
|
||||
void CreateSwitchOpParameters(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
|
||||
std::map<std::string, mindspore::ParameterPtr> *switch_loop_input);
|
||||
CNodePtr CreateStreamSwitchOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
|
||||
|
|
|
@ -35,7 +35,7 @@ class KernelInfo {
|
|||
select_kernel_build_info_ = nullptr;
|
||||
output_address_list_ = {};
|
||||
workspace_address_list_ = {};
|
||||
stream_id_ = 0;
|
||||
stream_id_ = UINT32_MAX;
|
||||
stream_distinction_label_ = kInvalidDistincLabel;
|
||||
graph_id_ = kInvalidGraphId;
|
||||
}
|
||||
|
|
|
@ -283,18 +283,37 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
|
|||
MS_EXCEPTION_IF_NULL(mem_manager_);
|
||||
auto graph_inputs = graph->inputs();
|
||||
auto graph_valid_input = graph->valid_inputs();
|
||||
for (size_t i = 0; i < graph_inputs.size(); i++) {
|
||||
std::vector<AnfNodePtr> need_alloc_nodes;
|
||||
for (size_t i = 0; i < graph_inputs.size(); ++i) {
|
||||
auto item = graph_inputs[i];
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
if (!item->isa<Parameter>()) {
|
||||
if (i < graph_valid_input.size() && !graph_valid_input[i]) {
|
||||
continue;
|
||||
}
|
||||
if (i < graph_valid_input.size() && !graph_valid_input[i]) {
|
||||
|
||||
if (AnfAlgo::CheckPrimitiveType(item, prim::kPrimMakeTuple)) {
|
||||
auto outs = AnfAlgo::GetAllOutput(item);
|
||||
for (auto &out : outs) {
|
||||
MS_EXCEPTION_IF_NULL(out);
|
||||
if (!out->isa<Parameter>()) {
|
||||
continue;
|
||||
}
|
||||
if (NodeOutputDeviceAddressExist(out, 0)) {
|
||||
continue;
|
||||
}
|
||||
need_alloc_nodes.push_back(out);
|
||||
}
|
||||
}
|
||||
if (!item->isa<Parameter>()) {
|
||||
continue;
|
||||
}
|
||||
if (NodeOutputDeviceAddressExist(item, 0)) {
|
||||
continue;
|
||||
}
|
||||
need_alloc_nodes.push_back(item);
|
||||
}
|
||||
|
||||
for (auto &item : need_alloc_nodes) {
|
||||
auto output_size = AnfAlgo::GetOutputTensorNum(item);
|
||||
for (size_t index = 0; index < output_size; index++) {
|
||||
TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
|
||||
|
|
|
@ -75,7 +75,6 @@ std::vector<TaskInfoPtr> LabelSwitchKernel::GenTask(const std::vector<AddressPtr
|
|||
|
||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> LabelSwitchDesc::GetKernelInfo() {
|
||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> label_switch_build_info{};
|
||||
|
||||
vector<string> input_format{kOpFormat_DEFAULT, kOpFormat_DEFAULT};
|
||||
vector<TypeId> input_type{kNumberTypeUInt32, kNumberTypeBool};
|
||||
if (input_format.size() != input_type.size()) {
|
||||
|
|
|
@ -282,8 +282,12 @@ bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPa
|
|||
|
||||
// Runs OptimizeAction on `res` with the kPynativePasses pass list and returns its result unchanged.
bool PynativeOptimizeAction(const ResourcePtr &res) {
  return OptimizeAction(res, kPynativePasses);
}
|
||||
|
||||
static bool IsCtrlSink() {
|
||||
static bool IsCtrlSink(const FuncGraphPtr &graph) {
|
||||
auto ms_ctx = MsContext::GetInstance();
|
||||
if (ms_ctx->execution_mode() != kGraphMode) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string device_target = ms_ctx->device_target();
|
||||
if (device_target != kAscendDevice) {
|
||||
return false;
|
||||
|
@ -293,12 +297,7 @@ static bool IsCtrlSink() {
|
|||
return false;
|
||||
}
|
||||
|
||||
const char *enable_ctrl_sink = std::getenv("ENABLE_CTRL_SINK");
|
||||
if (enable_ctrl_sink == nullptr) {
|
||||
return false;
|
||||
}
|
||||
std::string enable_ctrl_sink_str(enable_ctrl_sink);
|
||||
if (enable_ctrl_sink_str == "0") {
|
||||
if (graph != nullptr && CompileGraphs::ContainMixedTarget(graph)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -311,7 +310,7 @@ bool TaskEmitAction(const ResourcePtr &res) {
|
|||
}
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>();
|
||||
if (IsCtrlSink()) {
|
||||
if (IsCtrlSink(func_graph)) {
|
||||
res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph));
|
||||
return true;
|
||||
}
|
||||
|
@ -323,7 +322,7 @@ bool TaskEmitAction(const ResourcePtr &res) {
|
|||
std::shared_ptr<CompileGraphs> compile = std::make_shared<CompileGraphs>(bc_ptr, cut_list);
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (compile->ContainMixedTarget(func_graph)) {
|
||||
if (CompileGraphs::ContainMixedTarget(func_graph)) {
|
||||
bc_ptr->set_is_multi_graph_sink(false);
|
||||
context_ptr->set_loop_sink_flag(false);
|
||||
} else if (context_ptr->execution_mode() != kPynativeMode) {
|
||||
|
@ -341,7 +340,7 @@ bool ExecuteAction(const ResourcePtr &res) {
|
|||
MS_LOG(EXCEPTION) << "Execute args error";
|
||||
}
|
||||
|
||||
if (IsCtrlSink()) {
|
||||
if (IsCtrlSink(nullptr)) {
|
||||
if (!res->results()[kOutput].is<GraphId>()) {
|
||||
MS_LOG(EXCEPTION) << "Execute args error";
|
||||
}
|
||||
|
|
|
@ -996,5 +996,23 @@ bool AnfRuntimeAlgorithm::IsScalarOutput(const CNodePtr &cnode, size_t index) {
|
|||
}
|
||||
return shape.size() == kShape1dDims && shape[0] == 1;
|
||||
}
|
||||
|
||||
// Reorders `node_list` in place so that every node whose op name is found in kOptOperatorSet
// is placed after all nodes whose op name is not in the set. The relative order of the nodes
// inside each of the two groups is kept. Every node is null-checked before the list is modified.
void AnfRuntimeAlgorithm::ReorderExecList(NotNull<std::vector<CNodePtr> *> node_list) {
  std::vector<CNodePtr> leading_nodes;   // op name not in kOptOperatorSet
  std::vector<CNodePtr> trailing_nodes;  // op name in kOptOperatorSet

  for (const auto &cnode : *node_list) {
    MS_EXCEPTION_IF_NULL(cnode);
    const bool in_opt_set = kOptOperatorSet.find(AnfAlgo::GetCNodeName(cnode)) != kOptOperatorSet.end();
    auto &bucket = in_opt_set ? trailing_nodes : leading_nodes;
    bucket.emplace_back(cnode);
  }

  // Rebuild the caller's list: non-matching nodes first, matching nodes last.
  node_list->clear();
  node_list->insert(node_list->end(), leading_nodes.begin(), leading_nodes.end());
  node_list->insert(node_list->end(), trailing_nodes.begin(), trailing_nodes.end());
}
|
||||
|
||||
} // namespace session
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -189,6 +189,7 @@ class AnfRuntimeAlgorithm {
|
|||
static bool IsSwitchCall(const CNodePtr &call_node);
|
||||
static bool IsScalarInput(const CNodePtr &cnode, size_t index);
|
||||
static bool IsScalarOutput(const CNodePtr &cnode, size_t index);
|
||||
static void ReorderExecList(NotNull<std::vector<CNodePtr> *> node_list);
|
||||
};
|
||||
} // namespace session
|
||||
using AnfAlgo = session::AnfRuntimeAlgorithm;
|
||||
|
|
|
@ -40,7 +40,7 @@ static void InitUnionFindSet(NotNull<KernelGraphPtr> kg, const NotNull<UnionFind
|
|||
}
|
||||
memo->insert(kg.get());
|
||||
|
||||
const std::map<AnfNodePtr, std::vector<AnfNodePtr>> &real_inputs = kg->real_inputs();
|
||||
const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs = kg->real_inputs();
|
||||
for (auto &iter : real_inputs) {
|
||||
auto ¶ = iter.first;
|
||||
MS_EXCEPTION_IF_NULL(para);
|
||||
|
@ -67,7 +67,7 @@ static void UnionParentParameter(NotNull<KernelGraphPtr> kg, const NotNull<Union
|
|||
}
|
||||
memo->insert(kg.get());
|
||||
|
||||
const std::map<AnfNodePtr, std::vector<AnfNodePtr>> &real_inputs = kg->real_inputs();
|
||||
const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs = kg->real_inputs();
|
||||
for (auto &iter : real_inputs) {
|
||||
auto ¶ = iter.first;
|
||||
for (auto &arg : iter.second) {
|
||||
|
@ -178,16 +178,18 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGr
|
|||
for (auto &iter : graph_id_map) {
|
||||
auto &kg = iter.second;
|
||||
MS_EXCEPTION_IF_NULL(kg);
|
||||
const std::map<AnfNodePtr, std::vector<AnfNodePtr>> &real_inputs = kg->real_inputs();
|
||||
for (auto &in : kg->inputs()) {
|
||||
auto it = real_inputs.find(in);
|
||||
if (it == real_inputs.end()) {
|
||||
continue;
|
||||
}
|
||||
auto ¶meter = it->first;
|
||||
auto &args = it->second;
|
||||
std::set<std::pair<AnfNodePtr, AnfNodePtr>> memo;
|
||||
const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs = kg->real_inputs();
|
||||
for (auto &it : real_inputs) {
|
||||
auto ¶meter = it.first;
|
||||
auto &args = it.second;
|
||||
for (auto &arg : args) {
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
if (memo.find({parameter, arg}) != memo.end()) {
|
||||
continue;
|
||||
} else {
|
||||
memo.emplace(parameter, arg);
|
||||
}
|
||||
if (arg->isa<Parameter>()) {
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
MS_LOG(DEBUG) << "Parameter should be reused, no need insert assign, parameter: " << parameter->DebugString()
|
||||
|
@ -198,7 +200,7 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGr
|
|||
if (target_graph_iter == graph_id_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "Graph id " << AnfAlgo::GetGraphId(arg.get()) << " not found.";
|
||||
}
|
||||
InsertAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(arg), NOT_NULL(parameter));
|
||||
InsertMultipleAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(arg), NOT_NULL(parameter));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -290,17 +292,8 @@ void AscendControlParser::InsertControlDependToGraph(NotNull<KernelGraphPtr> kg,
|
|||
|
||||
void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node,
|
||||
const CNodePtr &last_label) {
|
||||
auto origin_return = kg->get_return();
|
||||
const std::vector<AnfNodePtr> &origin_return_inputs = origin_return->inputs();
|
||||
// if entry graph, replace return with make_tuple
|
||||
if (from_graph_call_node == nullptr || last_label == nullptr) {
|
||||
MS_LOG(INFO) << kg->ToString() << " is entry graph.";
|
||||
std::vector<AnfNodePtr> make_tuple_inputs = {std::make_shared<ValueNode>(prim::kPrimMakeTuple)};
|
||||
make_tuple_inputs.insert(make_tuple_inputs.end(), origin_return_inputs.begin() + 1, origin_return_inputs.end());
|
||||
auto make_tuple = kg->NewCNode(make_tuple_inputs);
|
||||
origin_return->set_inputs({origin_return->input(kCNodePrim), make_tuple});
|
||||
} else {
|
||||
// else replace return with label_goto
|
||||
// if not entry graph, replace return with label_goto
|
||||
if (from_graph_call_node != nullptr && last_label != nullptr) {
|
||||
auto label_goto =
|
||||
kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelGotoOpName)), last_label});
|
||||
MS_LOG(INFO) << "Insert end goto " << label_goto->DebugString() << " to " << kg->ToString();
|
||||
|
@ -443,6 +436,20 @@ std::tuple<CNodePtr, KernelGraphPtr> AscendControlParser::ParsePartial(NotNull<A
|
|||
return {partial_cnode, branch_kg};
|
||||
}
|
||||
|
||||
// Expands `from` and `to` with AnfAlgo::GetAllOutput (passing {prim::kPrimTupleGetItem} as the
// second argument) and calls InsertAssignToGraph once for each pair of outputs found at the same
// position. Raises an exception through MS_LOG(EXCEPTION) when the two expansions differ in size.
void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from,
                                                      NotNull<AnfNodePtr> to) {
  std::vector<AnfNodePtr> src_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem});
  std::vector<AnfNodePtr> dst_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem});
  MS_LOG(INFO) << "Insert multi-assign from [" << from->DebugString() << "] to [" << to->DebugString() << "]";

  const size_t src_count = src_outputs.size();
  const size_t dst_count = dst_outputs.size();
  if (src_count != dst_count) {
    MS_LOG(EXCEPTION) << "From outputs size[" << src_count << "] is not equal to to outputs size[" << dst_count
                      << "]";
  }

  // Sizes match here, so walking one index covers both lists.
  size_t pos = 0;
  while (pos < src_count) {
    InsertAssignToGraph(kg, NOT_NULL(src_outputs[pos]), NOT_NULL(dst_outputs[pos]));
    ++pos;
  }
}
|
||||
|
||||
void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from,
|
||||
NotNull<AnfNodePtr> to) {
|
||||
if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) &&
|
||||
|
@ -472,7 +479,16 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr>
|
|||
}
|
||||
memo->insert(graph.get());
|
||||
graph->SetExecOrderByDefault();
|
||||
const std::vector<CNodePtr> &cnodes = graph->execution_order();
|
||||
std::vector<CNodePtr> cnodes = graph->execution_order();
|
||||
|
||||
auto end_label_goto = graph->get_end_goto();
|
||||
if (cnodes.rbegin() != cnodes.rend() && *cnodes.rbegin() == end_label_goto) {
|
||||
cnodes.pop_back();
|
||||
}
|
||||
AnfAlgo::ReorderExecList(NOT_NULL(&cnodes));
|
||||
if (end_label_goto != nullptr) {
|
||||
cnodes.push_back(end_label_goto);
|
||||
}
|
||||
|
||||
std::vector<CNodePtr> execution_order;
|
||||
uint32_t child_order_index = 0;
|
||||
|
|
|
@ -52,6 +52,7 @@ class AscendControlParser {
|
|||
const CNodePtr &last_label);
|
||||
static std::tuple<CNodePtr, KernelGraphPtr> ParsePartial(NotNull<AnfNodePtr> node);
|
||||
|
||||
static void InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
|
||||
static void InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
|
||||
|
||||
// root graph order
|
||||
|
|
|
@ -550,7 +550,6 @@ void AscendSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_
|
|||
|
||||
void AscendSession::AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
||||
MS_LOG(INFO) << "Start!";
|
||||
device::KernelAdjust::GetInstance().Reorder(kernel_graph);
|
||||
opt::HideNopNode(kernel_graph.get());
|
||||
// Insert CLearZero op
|
||||
// prepare for next step from json get atomic info
|
||||
|
@ -583,7 +582,7 @@ void AscendSession::RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel
|
|||
|
||||
void AscendSession::AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
||||
MS_LOG(INFO) << "Start!";
|
||||
device::ascend::AscendStreamAssign::GetInstance().AssignStreamNew(kernel_graph);
|
||||
device::ascend::AscendStreamAssign::GetInstance().AssignStream(kernel_graph);
|
||||
MS_LOG(INFO) << "Finish!";
|
||||
}
|
||||
|
||||
|
@ -1539,6 +1538,11 @@ void AscendSession::SplitGraphs(NotNull<KernelGraphPtr> root_graph) {
|
|||
RecurseSplitGraph(root_graph, NOT_NULL(&memo));
|
||||
}
|
||||
memo.clear();
|
||||
// add maketuple to the end of the last child graph to suit old process
|
||||
auto output_graph = root_graph->child_graph_order().empty() ? root_graph : root_graph->child_graph_order().back();
|
||||
auto make_tuple = output_graph->NewCNode(
|
||||
{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())), output_graph->output()});
|
||||
output_graph->set_output(make_tuple);
|
||||
// replace the real input if the real input is a call
|
||||
RecurseToUpdateCallRealInput(root_graph, NOT_NULL(&memo));
|
||||
}
|
||||
|
|
|
@ -43,12 +43,28 @@ void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
|
|||
|
||||
std::vector<AnfNodePtr> GetCallRealOutputs(const AnfNodePtr &call_node) {
|
||||
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(call_node, 0);
|
||||
MS_EXCEPTION_IF_NULL(item_with_index.first);
|
||||
if (!AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimCall)) {
|
||||
return {item_with_index.first};
|
||||
AnfNodePtr node = item_with_index.first;
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
|
||||
auto outputs = AnfAlgo::GetAllOutput(node);
|
||||
std::set<AnfNodePtr> memo;
|
||||
std::vector<AnfNodePtr> new_output;
|
||||
for (auto &output : outputs) {
|
||||
if (memo.find(output) != memo.end()) {
|
||||
continue;
|
||||
}
|
||||
memo.insert(output);
|
||||
new_output.push_back(output);
|
||||
}
|
||||
if (new_output.size() == 1 && AnfAlgo::CheckPrimitiveType(new_output[0], prim::kPrimCall)) {
|
||||
node = new_output[0];
|
||||
}
|
||||
}
|
||||
if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) {
|
||||
return {node};
|
||||
}
|
||||
std::vector<AnfNodePtr> real_inputs;
|
||||
auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(item_with_index.first->cast<CNodePtr>());
|
||||
auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(node->cast<CNodePtr>());
|
||||
for (const auto &child_graph : child_graphs) {
|
||||
if (child_graph->get_output_null()) {
|
||||
continue;
|
||||
|
@ -623,18 +639,25 @@ void KernelGraph::ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodeP
|
|||
(void)node_output_edges_.erase(old_anf_node);
|
||||
}
|
||||
// update graph inputs in child graph
|
||||
auto it_real_inputs = real_inputs_.find(old_anf_node);
|
||||
auto it_real_inputs = std::find_if(real_inputs_.begin(), real_inputs_.end(),
|
||||
[&old_anf_node](const std::pair<AnfNodePtr, std::vector<AnfNodePtr>> &n) -> bool {
|
||||
return n.first == old_anf_node.get();
|
||||
});
|
||||
if (it_real_inputs != real_inputs_.end()) {
|
||||
// erase old parameter in map
|
||||
auto old_args = it_real_inputs->second;
|
||||
real_inputs_.erase(it_real_inputs);
|
||||
// insert new parameter to map
|
||||
auto iter = real_inputs_.find(new_anf_node);
|
||||
auto iter = std::find_if(real_inputs_.begin(), real_inputs_.end(),
|
||||
[&new_anf_node](const std::pair<AnfNodePtr, std::vector<AnfNodePtr>> &n) -> bool {
|
||||
return n.first == new_anf_node.get();
|
||||
});
|
||||
if (iter != real_inputs_.end()) {
|
||||
MS_LOG(WARNING) << new_anf_node->DebugString() << " already exist in real inputs, will be rewrited.";
|
||||
iter->second = it_real_inputs->second;
|
||||
iter->second = old_args;
|
||||
} else {
|
||||
real_inputs_[new_anf_node.get()] = it_real_inputs->second;
|
||||
real_inputs_.emplace_back(new_anf_node, old_args);
|
||||
}
|
||||
// erase old parameter in map
|
||||
real_inputs_.erase(old_anf_node);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -676,57 +699,33 @@ void KernelGraph::SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &ar
|
|||
MS_LOG(INFO) << "parameter: " << parameter->DebugString() << ", real input : " << arg->DebugString();
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
if (real_inputs_.find(parameter) == real_inputs_.end()) {
|
||||
real_inputs_[parameter] = std::vector<AnfNodePtr>();
|
||||
}
|
||||
auto &args = real_inputs_[parameter];
|
||||
(void)args.push_back(arg);
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> KernelGraph::GetRealInput(const AnfNodePtr ¶meter) {
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
auto iter = real_inputs_.find(parameter);
|
||||
auto iter = std::find_if(
|
||||
real_inputs_.begin(), real_inputs_.end(),
|
||||
[¶meter](const std::pair<AnfNodePtr, std::vector<AnfNodePtr>> &n) -> bool { return n.first == parameter; });
|
||||
if (iter != real_inputs_.end()) {
|
||||
return iter->second;
|
||||
auto &args = iter->second;
|
||||
args.push_back(arg);
|
||||
} else {
|
||||
real_inputs_.emplace_back(parameter, std::vector<AnfNodePtr>(1, arg));
|
||||
}
|
||||
MS_LOG(EXCEPTION) << parameter->DebugString() << " not found.";
|
||||
}
|
||||
|
||||
void KernelGraph::UpdateCallRealInput() {
|
||||
MS_LOG(INFO) << "Update graph id: " << graph_id_;
|
||||
std::map<AnfNodePtr, std::vector<AnfNodePtr>> real_inputs_map;
|
||||
std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> real_inputs_map;
|
||||
for (auto &it : real_inputs_) {
|
||||
auto parameter = it.first;
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
auto real_inputs = it.second;
|
||||
std::vector<AnfNodePtr> new_real_inputs;
|
||||
std::set<AnfNodePtr> erase_real_inputs;
|
||||
for (auto &real_input : real_inputs) {
|
||||
// if real input is a call node ,find the child graph output act as the new real input
|
||||
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(real_input, 0);
|
||||
MS_EXCEPTION_IF_NULL(item_with_index.first);
|
||||
if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimCall)) {
|
||||
(void)erase_real_inputs.insert(item_with_index.first);
|
||||
new_real_inputs = GetCallRealOutputs(item_with_index.first);
|
||||
continue;
|
||||
}
|
||||
auto tmp_real_input = GetCallRealOutputs(item_with_index.first);
|
||||
std::copy(tmp_real_input.begin(), tmp_real_input.end(), std::back_inserter(new_real_inputs));
|
||||
}
|
||||
for (auto &erase_node : erase_real_inputs) {
|
||||
MS_LOG(INFO) << "paramter: " << parameter->DebugString() << " erase real input:" << erase_node->DebugString();
|
||||
for (auto iter = real_inputs.begin(); iter != real_inputs.end();) {
|
||||
if (*iter == erase_node) {
|
||||
iter = real_inputs.erase(iter);
|
||||
} else {
|
||||
++iter;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto &new_real_input : new_real_inputs) {
|
||||
MS_LOG(INFO) << "paramter: " << parameter->DebugString()
|
||||
<< " insert real input:" << new_real_input->DebugString();
|
||||
(void)real_inputs.push_back(new_real_input);
|
||||
}
|
||||
real_inputs_map[parameter] = real_inputs;
|
||||
real_inputs_map.emplace_back(parameter, new_real_inputs);
|
||||
}
|
||||
real_inputs_ = real_inputs_map;
|
||||
}
|
||||
|
|
|
@ -127,8 +127,7 @@ class KernelGraph : public FuncGraph {
|
|||
// find anf node in graph
|
||||
std::vector<CNodePtr> FindNodeByPrimitive(const PrimitivePtr &primitive) const;
|
||||
// get real inputs
|
||||
const std::map<AnfNodePtr, std::vector<AnfNodePtr>> &real_inputs() const { return real_inputs_; }
|
||||
std::vector<AnfNodePtr> GetRealInput(const AnfNodePtr ¶meter);
|
||||
const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs() const { return real_inputs_; }
|
||||
void SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &arg);
|
||||
// used to dump ir
|
||||
std::string ToString() const override;
|
||||
|
@ -197,7 +196,7 @@ class KernelGraph : public FuncGraph {
|
|||
// parameter graph
|
||||
std::shared_ptr<KernelGraph> parent_graph_;
|
||||
// record real parameters,inputs_ is the formal parameters
|
||||
std::map<AnfNodePtr, std::vector<AnfNodePtr>> real_inputs_;
|
||||
std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> real_inputs_;
|
||||
|
||||
CNodePtr start_label_;
|
||||
CNodePtr end_goto_;
|
||||
|
|
|
@ -727,23 +727,7 @@ void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {
|
|||
summary_callback_ = callback;
|
||||
}
|
||||
|
||||
void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) {
|
||||
MS_EXCEPTION_IF_NULL(node_list);
|
||||
std::vector<CNodePtr> all_opt_list;
|
||||
std::vector<CNodePtr> non_opt_list;
|
||||
|
||||
for (const auto &node : *node_list) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (kOptOperatorSet.find(AnfAlgo::GetCNodeName(node)) != kOptOperatorSet.end()) {
|
||||
all_opt_list.emplace_back(node);
|
||||
} else {
|
||||
non_opt_list.emplace_back(node);
|
||||
}
|
||||
}
|
||||
node_list->clear();
|
||||
(void)std::copy(non_opt_list.begin(), non_opt_list.end(), std::back_inserter(*node_list));
|
||||
(void)std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list));
|
||||
}
|
||||
// Forwards to AnfAlgo::ReorderExecList, which reorders the list in place.
void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) {
  // NOT_NULL wraps the raw pointer before it is handed to ReorderExecList.
  AnfAlgo::ReorderExecList(NOT_NULL(node_list));
}
|
||||
|
||||
void SessionBasic::GetSummaryNodes(KernelGraph *graph) {
|
||||
MS_LOG(DEBUG) << "Update summary Start";
|
||||
|
|
|
@ -857,6 +857,7 @@ FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) {
|
|||
}
|
||||
|
||||
bool CompileGraphs::ContainMixedTarget(const FuncGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto graph_manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(graph_manager);
|
||||
FuncGraphSet graphs = graph_manager->func_graphs();
|
||||
|
|
|
@ -124,7 +124,7 @@ class CompileGraphs {
|
|||
void Compile(const FuncGraphPtr &func_graph);
|
||||
FinalVMPtr Link(const FuncGraphPtr &func_graph);
|
||||
FinalVMPtr CompileAndLink(const FuncGraphPtr &func_graph);
|
||||
bool ContainMixedTarget(const FuncGraphPtr &graph);
|
||||
static bool ContainMixedTarget(const FuncGraphPtr &graph);
|
||||
|
||||
private:
|
||||
InstSet insts_;
|
||||
|
|
|
@ -26,12 +26,12 @@ void AscendLabelAssign::AssignLabel(NotNull<std::shared_ptr<session::KernelGraph
|
|||
// Always returns 1; the graph argument is not inspected.
uint32_t AscendLabelAssign::GetLabelNum(NotNull<const session::KernelGraph *> graph) {
  return 1;
}
|
||||
// Always returns 1; the graph argument is not inspected.
uint32_t AscendLabelAssign::GetLabelNum(NotNull<std::shared_ptr<session::KernelGraph>> graph) {
  return 1;
}
|
||||
|
||||
void AscendStreamAssign::AssignStreamNew(const KernelGraphPtr &graph) { return; }
|
||||
|
||||
uint32_t AscendStreamAssign::GetTotalStreamNum() const { return 1; }
|
||||
// No-op implementation: returns immediately without touching the graph.
void AscendStreamAssign::AssignStream(const KernelGraphPtr &graph) {
  // Intentionally empty.
}
|
||||
|
||||
// No-op implementation: the output list is left exactly as the caller passed it.
void AscendStreamAssign::GetWaitStreams(vector<uint32_t> *wait_active_stream_list) {
  // Intentionally empty.
}
|
||||
|
||||
// No-op implementation: the output list is left exactly as the caller passed it.
void AscendStreamAssign::GetHcomStreams(std::vector<uint32_t> *streams) {
  // Intentionally empty.
}
|
||||
|
||||
namespace tasksink {
|
||||
bool TaskGenerator::GenTasks(const std::vector<CNodePtr> &anf_node_list, std::vector<TaskInfoPtr> *const task_info_list,
|
||||
uint32_t graph_id) {
|
||||
|
@ -39,7 +39,6 @@ bool TaskGenerator::GenTasks(const std::vector<CNodePtr> &anf_node_list, std::ve
|
|||
}
|
||||
} // namespace tasksink
|
||||
} // namespace ascend
|
||||
void KernelAdjust::Reorder(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { return; }
|
||||
// No-op implementation: returns immediately without modifying the kernel graph.
void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
  // Intentionally empty.
}
|
||||
// Always returns true; the kernel graph argument is not inspected.
bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
  return true;
}
|
||||
// Always returns true.
bool KernelAdjust::NeedInsertSwitch() {
  return true;
}
|
||||
|
|
Loading…
Reference in New Issue