forked from mindspore-Ecosystem/mindspore
fix hcom parallel
This commit is contained in:
parent
77587bb04d
commit
5a67a6adb0
|
@ -104,7 +104,7 @@ void AscendStreamAssign::ReorderIndependentOrders(const NotNull<KernelGraphPtr>
|
|||
continue;
|
||||
}
|
||||
|
||||
auto res = FindTargetOp(begin, end, cur_independent);
|
||||
auto res = FindTargetOp(begin, end, cur_independent, false);
|
||||
if (res != end) {
|
||||
flag = true;
|
||||
exe_orders.emplace_back(cur_independent);
|
||||
|
@ -247,10 +247,6 @@ void AscendStreamAssign::AssignHcom(const NotNull<KernelGraphPtr> &graph_ptr) {
|
|||
}
|
||||
group_hcom_graph_map_[diff_group.first] = hcom_graph_map;
|
||||
}
|
||||
|
||||
for (const auto &item : group_hcom_graph_map_) {
|
||||
MS_LOG_INFO << "group id:" << item.first << "; hcom stream nums:" << item.second.size();
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t AscendStreamAssign::AssignHcomStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph) {
|
||||
|
@ -787,7 +783,7 @@ void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull<KernelGraphPt
|
|||
CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it));
|
||||
it = cnodes.insert(it + 1, send_cnode_ptr);
|
||||
|
||||
auto target = FindTargetOp(it, cnodes.end(), *(it - 1));
|
||||
auto target = FindTargetOp(it, cnodes.end(), *(it - 1), true);
|
||||
if (target == cnodes.end()) {
|
||||
MS_LOG(WARNING) << "Hcom node:" << (*(it - 1))->fullname_with_scope()
|
||||
<< ", can't find target for insert recv op, no insert send/recv";
|
||||
|
@ -795,11 +791,6 @@ void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull<KernelGraphPt
|
|||
continue;
|
||||
}
|
||||
|
||||
if (IsHcom(*target)) {
|
||||
it = cnodes.erase(it);
|
||||
continue;
|
||||
}
|
||||
|
||||
// deal recv op
|
||||
uint32_t stream_id = AnfAlgo::GetStreamId(*target);
|
||||
CNodePtr recv_cnode_ptr = CreateRecvApplyKernel(graph_ptr, cur_event_id, stream_id);
|
||||
|
@ -834,15 +825,26 @@ void AscendStreamAssign::InsertEventHcomDependCommonBak(const NotNull<KernelGrap
|
|||
}
|
||||
|
||||
// get the input which is located last in the execution order
|
||||
auto last_input_cnode = GetLastInputCnode(graph_ptr, cur_cnode_ptr);
|
||||
auto it = std::find(cnodes.begin(), cnodes.end(), last_input_cnode);
|
||||
if (it == cnodes.end()) {
|
||||
MS_LOG(ERROR) << "Hcom:" << AnfAlgo::GetCNodeName(cur_cnode_ptr)
|
||||
<< "get last input:" << AnfAlgo::GetCNodeName(last_input_cnode) << "; but last input not in cnodes";
|
||||
} else {
|
||||
vector<CNodePtr> inputs_cnode = GetLastInputCnode(graph_ptr, cur_cnode_ptr);
|
||||
if (inputs_cnode.empty()) {
|
||||
MS_LOG(WARNING) << "Hcom op:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << " can't find inputs nodes";
|
||||
continue;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Current hcom:" << AnfAlgo::GetCNodeName(cur_cnode_ptr)
|
||||
<< "; inputs cnode size:" << inputs_cnode.size();
|
||||
|
||||
for (size_t j = 0; j < inputs_cnode.size(); j++) {
|
||||
auto &cur_input = inputs_cnode.at(j);
|
||||
MS_LOG(INFO) << "The index:" << j << " input, name:" << AnfAlgo::GetCNodeName(cur_input);
|
||||
uint32_t cur_event_id = resource_manager.ApplyNewEvent();
|
||||
auto last_stream_id = AnfAlgo::GetStreamId(last_input_cnode);
|
||||
auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, last_stream_id);
|
||||
auto pre_stream_id = AnfAlgo::GetStreamId(cur_input);
|
||||
auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, pre_stream_id);
|
||||
auto it = std::find(cnodes.begin(), cnodes.end(), cur_input);
|
||||
if (it == cnodes.end()) {
|
||||
MS_LOG_EXCEPTION << "Hcom:" << AnfAlgo::GetCNodeName(cur_cnode_ptr)
|
||||
<< " can't find input node:" << AnfAlgo::GetCNodeName(cur_input);
|
||||
}
|
||||
cnodes.insert(it + 1, send);
|
||||
uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
|
||||
auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_stream_id);
|
||||
|
@ -855,26 +857,56 @@ void AscendStreamAssign::InsertEventHcomDependCommonBak(const NotNull<KernelGrap
|
|||
MS_LOG(INFO) << "After hcom depend common, total event nums:" << resource_manager.get_cur_event_num();
|
||||
}
|
||||
|
||||
CNodePtr AscendStreamAssign::GetLastInputCnode(const NotNull<KernelGraphPtr> &graph_ptr,
|
||||
const CNodePtr &cur_cnode_ptr) {
|
||||
vector<CNodePtr> AscendStreamAssign::GetLastInputCnode(const NotNull<KernelGraphPtr> &graph_ptr,
|
||||
const CNodePtr &cur_cnode_ptr) {
|
||||
auto cnode_ptr_list = graph_ptr->execution_order();
|
||||
auto group_name = AnfAlgo::GetNodeAttr<std::string>(cur_cnode_ptr, kAttrGroup);
|
||||
auto input_cnodes = GetInputKernels(cur_cnode_ptr);
|
||||
if (input_cnodes.empty()) {
|
||||
return nullptr;
|
||||
return {};
|
||||
}
|
||||
auto it_pos = cnode_ptr_list.begin();
|
||||
|
||||
for (auto &cnode : input_cnodes) {
|
||||
auto it = std::find(it_pos, cnode_ptr_list.end(), cnode);
|
||||
if (it != cnode_ptr_list.end()) {
|
||||
it_pos = it;
|
||||
// record max index node for each stream
|
||||
std::map<uint32_t, std::pair<CNodePtr, uint32_t>> result;
|
||||
for (size_t i = 0; i < input_cnodes.size(); i++) {
|
||||
auto &cur_input = input_cnodes.at(i);
|
||||
auto stream_id = AnfAlgo::GetStreamId(cur_input);
|
||||
auto cur_index = GetIndexByKey(graph_ptr, cur_input.get());
|
||||
if (cur_index == UINT32_MAX) {
|
||||
MS_LOG_EXCEPTION << "The input node:" << AnfAlgo::GetCNodeName(cur_input) << " is not found in graph";
|
||||
}
|
||||
auto it = result.find(stream_id);
|
||||
if (it == result.end()) {
|
||||
result[stream_id] = std::make_pair(cur_input, cur_index);
|
||||
} else {
|
||||
auto max_index = it->second.second;
|
||||
if (cur_index > max_index) {
|
||||
result[stream_id] = std::make_pair(cur_input, cur_index);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (it_pos == cnode_ptr_list.begin() && *it_pos != input_cnodes.front()) {
|
||||
MS_LOG(ERROR) << "The input of node:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "was not found";
|
||||
|
||||
vector<CNodePtr> final_inputs;
|
||||
uint32_t max = 0;
|
||||
CNodePtr max_common_cnode = nullptr;
|
||||
for (const auto &item : result) {
|
||||
if (IsHcom(item.second.first)) {
|
||||
auto cur_group = AnfAlgo::GetNodeAttr<std::string>(item.second.first, kAttrGroup);
|
||||
if (cur_group == group_name) {
|
||||
continue;
|
||||
} else {
|
||||
final_inputs.emplace_back(item.second.first);
|
||||
}
|
||||
} else {
|
||||
if (item.second.second > max) {
|
||||
max_common_cnode = item.second.first;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return *it_pos;
|
||||
if (max_common_cnode != nullptr) {
|
||||
final_inputs.emplace_back(max_common_cnode);
|
||||
}
|
||||
return final_inputs;
|
||||
}
|
||||
|
||||
vector<CNodePtr> AscendStreamAssign::GetInputKernels(const CNodePtr &node) {
|
||||
|
@ -956,9 +988,7 @@ void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr>
|
|||
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
|
||||
auto cnode_ptr_list = graph_ptr->execution_order();
|
||||
// key: group id; value: hcom indices grouped by stream id
|
||||
std::map<std::string, std::map<uint32_t, vector<size_t>>> group_hcom_index;
|
||||
std::map<std::string, uint32_t> group_first_hcom_stream;
|
||||
std::map<std::string, uint32_t> group_last_hcom_stream;
|
||||
std::map<std::string, std::vector<std::pair<uint32_t, vector<size_t>>>> group_hcom_index;
|
||||
for (size_t i = 0; i < cnode_ptr_list.size(); i++) {
|
||||
auto cur_cnode = cnode_ptr_list[i];
|
||||
if (!IsHcom(cur_cnode)) {
|
||||
|
@ -969,67 +999,60 @@ void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr>
|
|||
MS_LOG_EXCEPTION << "hcom cnode " << cur_cnode->DebugString() << " has no group attr";
|
||||
}
|
||||
auto group_name = AnfAlgo::GetNodeAttr<std::string>(cur_cnode, kAttrGroup);
|
||||
MS_LOG(INFO) << "Hcom node name:" << AnfAlgo::GetCNodeName(cur_cnode) << "; group:" << group_name
|
||||
<< "; stream id:" << cur_stream_id;
|
||||
auto iter = group_hcom_index.find(group_name);
|
||||
if (iter == group_hcom_index.end()) {
|
||||
std::map<uint32_t, vector<size_t>> hcom_index;
|
||||
hcom_index[cur_stream_id] = {i};
|
||||
std::vector<std::pair<uint32_t, vector<size_t>>> hcom_index;
|
||||
hcom_index.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i}));
|
||||
group_hcom_index[group_name] = hcom_index;
|
||||
} else {
|
||||
auto &hcom_index = iter->second;
|
||||
auto it = hcom_index.find(cur_stream_id);
|
||||
if (it == hcom_index.end()) {
|
||||
hcom_index[cur_stream_id] = {i};
|
||||
} else {
|
||||
it->second.emplace_back(i);
|
||||
bool exit = false;
|
||||
for (auto &item : hcom_index) {
|
||||
if (item.first == cur_stream_id) {
|
||||
item.second.emplace_back(i);
|
||||
exit = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!exit) {
|
||||
hcom_index.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// record first hcom stream id
|
||||
auto it = group_first_hcom_stream.find(group_name);
|
||||
if (it == group_first_hcom_stream.end()) {
|
||||
group_first_hcom_stream[group_name] = cur_stream_id;
|
||||
}
|
||||
|
||||
// record last hcom stream id
|
||||
it = group_last_hcom_stream.find(group_name);
|
||||
if (it != group_last_hcom_stream.end()) {
|
||||
it->second = cur_stream_id;
|
||||
} else {
|
||||
group_last_hcom_stream[group_name] = cur_stream_id;
|
||||
for (const auto &hcom_index : group_hcom_index) {
|
||||
MS_LOG(DEBUG) << "Group:" << hcom_index.first;
|
||||
for (const auto &item : hcom_index.second) {
|
||||
MS_LOG(DEBUG) << "stream id:" << item.first;
|
||||
for (const auto &index : item.second) {
|
||||
MS_LOG(DEBUG) << "hcom index:" << index;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto &hcom_index : group_hcom_index) {
|
||||
if (hcom_index.second.size() < 2) {
|
||||
MS_LOG(INFO) << "Different stream hcom size is less than 2, no need insert event between them";
|
||||
return;
|
||||
MS_LOG(INFO) << "Group:" << hcom_index.first
|
||||
<< "; different stream hcom size is less than 2, no need insert event between them";
|
||||
continue;
|
||||
}
|
||||
auto group_name = hcom_index.first;
|
||||
auto it = group_first_hcom_stream.find(group_name);
|
||||
if (it == group_first_hcom_stream.end()) {
|
||||
MS_LOG_EXCEPTION << "Can't find first hcom stream, hcom group id:" << group_name;
|
||||
}
|
||||
auto first_hcom_stream = it->second;
|
||||
|
||||
it = group_last_hcom_stream.find(group_name);
|
||||
if (it == group_last_hcom_stream.end()) {
|
||||
MS_LOG_EXCEPTION << "Can't find last hcom stream, hcom group id:" << group_name;
|
||||
}
|
||||
auto last_hcom_stream = it->second;
|
||||
InsertEventBetweenHcom(graph_ptr, hcom_index.second, first_hcom_stream, last_hcom_stream);
|
||||
InsertEventBetweenHcom(graph_ptr, hcom_index.second);
|
||||
MS_LOG(INFO) << "After hcom depend hcom, total event nums:" << resource_manager.get_cur_event_num();
|
||||
}
|
||||
}
|
||||
|
||||
void AscendStreamAssign::InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &graph_ptr,
|
||||
const map<uint32_t, vector<size_t>> &hcom_index,
|
||||
uint32_t first_hcom_stream, uint32_t last_hcom_stream) {
|
||||
const std::vector<std::pair<uint32_t, vector<size_t>>> &hcom_index) {
|
||||
vector<CNodePtr> orders;
|
||||
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
|
||||
auto cnode_ptr_list = graph_ptr->execution_order();
|
||||
uint32_t cur_event_id = resource_manager.ApplyNewEvent();
|
||||
size_t first_stream_last_index = hcom_index.at(first_hcom_stream).back();
|
||||
size_t last_stream_first_index = hcom_index.at(last_hcom_stream).front();
|
||||
size_t first_stream_last_index = hcom_index[0].second.back();
|
||||
size_t last_stream_first_index = hcom_index.back().second.front();
|
||||
MS_LOG(INFO) << "First stream last index:" << first_stream_last_index
|
||||
<< "; last stream first index:" << last_stream_first_index;
|
||||
std::copy(cnode_ptr_list.begin(), cnode_ptr_list.begin() + first_stream_last_index, std::back_inserter(orders));
|
||||
for (size_t i = first_stream_last_index; i <= last_stream_first_index; i++) {
|
||||
auto cur_cnode = cnode_ptr_list[i];
|
||||
|
@ -1049,7 +1072,17 @@ void AscendStreamAssign::InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &g
|
|||
orders.emplace_back(recv);
|
||||
orders.emplace_back(cur_cnode);
|
||||
} else {
|
||||
auto cur_stream_hcom_size = hcom_index.at(cur_hcom_stream_id).size();
|
||||
size_t cur_stream_hcom_size = UINT32_MAX;
|
||||
size_t first_index = UINT32_MAX;
|
||||
size_t last_index = UINT32_MAX;
|
||||
for (const auto &item : hcom_index) {
|
||||
if (item.first == cur_hcom_stream_id) {
|
||||
cur_stream_hcom_size = item.second.size();
|
||||
first_index = item.second.front();
|
||||
last_index = item.second.back();
|
||||
}
|
||||
}
|
||||
|
||||
if (cur_stream_hcom_size == 1) {
|
||||
auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
|
||||
orders.emplace_back(recv);
|
||||
|
@ -1059,12 +1092,12 @@ void AscendStreamAssign::InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &g
|
|||
orders.emplace_back(send);
|
||||
} else {
|
||||
// current stream, first hcom:add recv op
|
||||
if (i == hcom_index.at(cur_hcom_stream_id).front()) {
|
||||
if (i == first_index) {
|
||||
auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
|
||||
orders.emplace_back(recv);
|
||||
cur_event_id = resource_manager.ApplyNewEvent();
|
||||
orders.emplace_back(cur_cnode);
|
||||
} else if (i == hcom_index.at(cur_hcom_stream_id).back()) {
|
||||
} else if (i == last_index) {
|
||||
// current stream, last hcom:add send op
|
||||
orders.emplace_back(cur_cnode);
|
||||
auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
|
||||
|
@ -1080,19 +1113,19 @@ void AscendStreamAssign::InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &g
|
|||
graph_ptr->set_execution_order(orders);
|
||||
}
|
||||
|
||||
bool AscendStreamAssign::IsSatisfiedHcom(const std::map<uint32_t, vector<size_t>> &hcom_index, const CNodePtr &node_ptr,
|
||||
size_t index) {
|
||||
bool AscendStreamAssign::IsSatisfiedHcom(const std::vector<std::pair<uint32_t, vector<size_t>>> &hcom_index,
|
||||
const CNodePtr &node_ptr, size_t index) {
|
||||
MS_EXCEPTION_IF_NULL(node_ptr);
|
||||
auto cur_hcom_stream_id = AnfAlgo::GetStreamId(node_ptr);
|
||||
auto it = hcom_index.find(cur_hcom_stream_id);
|
||||
if (it == hcom_index.end()) {
|
||||
return false;
|
||||
for (const auto &item : hcom_index) {
|
||||
if (item.first == cur_hcom_stream_id) {
|
||||
auto it = std::find(item.second.begin(), item.second.end(), index);
|
||||
if (it != item.second.end()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
auto iter = std::find(hcom_index.at(cur_hcom_stream_id).begin(), hcom_index.at(cur_hcom_stream_id).end(), index);
|
||||
if (iter == hcom_index.at(cur_hcom_stream_id).end()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
// section6
|
||||
|
@ -1110,7 +1143,7 @@ void AscendStreamAssign::InsertEventForIndependentParallel(const NotNull<KernelG
|
|||
CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it));
|
||||
it = cnodes.insert(it + 1, send_cnode_ptr);
|
||||
|
||||
auto target = FindTargetOp(it, cnodes.end(), *(it - 1));
|
||||
auto target = FindTargetOp(it, cnodes.end(), *(it - 1), false);
|
||||
if (target == cnodes.end()) {
|
||||
MS_LOG(DEBUG) << "Independ node[" << (*(it - 1))->fullname_with_scope()
|
||||
<< "] can't find target for insert recv op, no insert send/recv";
|
||||
|
@ -1441,7 +1474,8 @@ CNodePtr AscendStreamAssign::CreateRecvApplyKernel(const NotNull<KernelGraphPtr>
|
|||
}
|
||||
|
||||
vector<CNodePtr>::iterator AscendStreamAssign::FindTargetOp(vector<CNodePtr>::iterator begin,
|
||||
vector<CNodePtr>::iterator end, const CNodePtr &node) {
|
||||
vector<CNodePtr>::iterator end, const CNodePtr &node,
|
||||
bool exclude_hcom) {
|
||||
while (begin != end) {
|
||||
auto inputs = (*begin)->inputs();
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
|
@ -1451,16 +1485,22 @@ vector<CNodePtr>::iterator AscendStreamAssign::FindTargetOp(vector<CNodePtr>::it
|
|||
auto new_inputs = cnode->inputs();
|
||||
for (size_t j = 1; j < new_inputs.size(); j++) {
|
||||
auto new_real_input = AnfAlgo::VisitKernel(new_inputs[j], 0);
|
||||
// find the target node, skipping hcom ops when exclude_hcom is set; events for hcom are inserted in the InsertEventHcomDependCommonBak function
|
||||
// only insert one time
|
||||
if (node == new_real_input.first) {
|
||||
MS_LOG(DEBUG) << "Nop node find target op[" << (*begin)->DebugString() << "]";
|
||||
return begin;
|
||||
if (!(exclude_hcom && IsHcom(*begin))) {
|
||||
MS_LOG(DEBUG) << "Nop node find target op[" << (*begin)->DebugString() << "]";
|
||||
return begin;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto real_input = AnfAlgo::VisitKernel(input, 0);
|
||||
if (node == real_input.first) {
|
||||
MS_LOG(DEBUG) << "Find target op[" << (*begin)->DebugString() << "]";
|
||||
return begin;
|
||||
if (!(exclude_hcom && IsHcom(*begin))) {
|
||||
MS_LOG(DEBUG) << "Nop node find target op[" << (*begin)->DebugString() << "]";
|
||||
return begin;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#include "runtime/base.h"
|
||||
#include "runtime/rt_model.h"
|
||||
#include "runtime/stream.h"
|
||||
|
@ -149,12 +150,13 @@ class AscendStreamAssign {
|
|||
void InsertEventHcomDependCommon(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||
void InsertEventHcomDependCommonBak(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||
void InsertEventHcomDependHcom(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||
void InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &graph_ptr, const map<uint32_t, vector<size_t>> &hcom_index,
|
||||
uint32_t first_hcom_stream, uint32_t last_hcom_stream);
|
||||
void InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &graph_ptr,
|
||||
const std::vector<std::pair<uint32_t, vector<size_t>>> &hcom_index);
|
||||
|
||||
void AdjustAtomicAddrCleanOrder(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||
CNodePtr GetLastInputCnode(const NotNull<KernelGraphPtr> &graph_ptr, const CNodePtr &cur_cnode_ptr);
|
||||
bool IsSatisfiedHcom(const std::map<uint32_t, vector<size_t>> &hcom_index, const CNodePtr &node_ptr, size_t index);
|
||||
vector<CNodePtr> GetLastInputCnode(const NotNull<KernelGraphPtr> &graph_ptr, const CNodePtr &cur_cnode_ptr);
|
||||
bool IsSatisfiedHcom(const std::vector<std::pair<uint32_t, vector<size_t>>> &hcom_index, const CNodePtr &node_ptr,
|
||||
size_t index);
|
||||
|
||||
void GetProcessedStream(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||
void GetNeedActiveStreams(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||
|
@ -169,7 +171,7 @@ class AscendStreamAssign {
|
|||
bool IsIndependentNode(const CNodePtr &node_ptr);
|
||||
bool IsProcessedStream(uint32_t stream_id);
|
||||
vector<CNodePtr>::iterator FindTargetOp(vector<CNodePtr>::iterator begin, vector<CNodePtr>::iterator end,
|
||||
const CNodePtr &node);
|
||||
const CNodePtr &node, bool exclude_hcom);
|
||||
void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t> *parallel_streams);
|
||||
void SetLoopSink();
|
||||
|
||||
|
@ -199,6 +201,7 @@ class AscendStreamAssign {
|
|||
std::vector<uint32_t> need_first_active_streams_{};
|
||||
std::set<CNodeKey> independent_targets_;
|
||||
|
||||
// key: group name; value: map from graph id to the set of stream ids
|
||||
std::map<std::string, std::map<uint32_t, std::set<uint32_t>>> group_hcom_graph_map_;
|
||||
// key:graph id, value:stream set
|
||||
std::map<uint32_t, std::set<uint32_t>> independent_graph_map_;
|
||||
|
|
Loading…
Reference in New Issue