fix hcom parallel

This commit is contained in:
gukecai 2020-12-03 19:17:47 +08:00
parent 77587bb04d
commit 5a67a6adb0
2 changed files with 139 additions and 96 deletions

View File

@ -104,7 +104,7 @@ void AscendStreamAssign::ReorderIndependentOrders(const NotNull<KernelGraphPtr>
continue;
}
auto res = FindTargetOp(begin, end, cur_independent);
auto res = FindTargetOp(begin, end, cur_independent, false);
if (res != end) {
flag = true;
exe_orders.emplace_back(cur_independent);
@ -247,10 +247,6 @@ void AscendStreamAssign::AssignHcom(const NotNull<KernelGraphPtr> &graph_ptr) {
}
group_hcom_graph_map_[diff_group.first] = hcom_graph_map;
}
for (const auto &item : group_hcom_graph_map_) {
MS_LOG_INFO << "group id:" << item.first << "; hcom stream nums:" << item.second.size();
}
}
uint32_t AscendStreamAssign::AssignHcomStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph) {
@ -787,7 +783,7 @@ void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull<KernelGraphPt
CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it));
it = cnodes.insert(it + 1, send_cnode_ptr);
auto target = FindTargetOp(it, cnodes.end(), *(it - 1));
auto target = FindTargetOp(it, cnodes.end(), *(it - 1), true);
if (target == cnodes.end()) {
MS_LOG(WARNING) << "Hcom node:" << (*(it - 1))->fullname_with_scope()
<< ", can't find target for insert recv op, no insert send/recv";
@ -795,11 +791,6 @@ void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull<KernelGraphPt
continue;
}
if (IsHcom(*target)) {
it = cnodes.erase(it);
continue;
}
// deal recv op
uint32_t stream_id = AnfAlgo::GetStreamId(*target);
CNodePtr recv_cnode_ptr = CreateRecvApplyKernel(graph_ptr, cur_event_id, stream_id);
@ -834,15 +825,26 @@ void AscendStreamAssign::InsertEventHcomDependCommonBak(const NotNull<KernelGrap
}
// get the input which located in the lastr exe orders
auto last_input_cnode = GetLastInputCnode(graph_ptr, cur_cnode_ptr);
auto it = std::find(cnodes.begin(), cnodes.end(), last_input_cnode);
if (it == cnodes.end()) {
MS_LOG(ERROR) << "Hcom:" << AnfAlgo::GetCNodeName(cur_cnode_ptr)
<< "get last input:" << AnfAlgo::GetCNodeName(last_input_cnode) << "; but last input not in cnodes";
} else {
vector<CNodePtr> inputs_cnode = GetLastInputCnode(graph_ptr, cur_cnode_ptr);
if (inputs_cnode.empty()) {
MS_LOG(WARNING) << "Hcom op:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << " can't find inputs nodes";
continue;
}
MS_LOG(INFO) << "Current hcom:" << AnfAlgo::GetCNodeName(cur_cnode_ptr)
<< "; inputs cnode size:" << inputs_cnode.size();
for (size_t j = 0; j < inputs_cnode.size(); j++) {
auto &cur_input = inputs_cnode.at(j);
MS_LOG(INFO) << "The index:" << j << " input, name:" << AnfAlgo::GetCNodeName(cur_input);
uint32_t cur_event_id = resource_manager.ApplyNewEvent();
auto last_stream_id = AnfAlgo::GetStreamId(last_input_cnode);
auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, last_stream_id);
auto pre_stream_id = AnfAlgo::GetStreamId(cur_input);
auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, pre_stream_id);
auto it = std::find(cnodes.begin(), cnodes.end(), cur_input);
if (it == cnodes.end()) {
MS_LOG_EXCEPTION << "Hcom:" << AnfAlgo::GetCNodeName(cur_cnode_ptr)
<< " can't find input node:" << AnfAlgo::GetCNodeName(cur_input);
}
cnodes.insert(it + 1, send);
uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_stream_id);
@ -855,26 +857,56 @@ void AscendStreamAssign::InsertEventHcomDependCommonBak(const NotNull<KernelGrap
MS_LOG(INFO) << "After hcom depend common, total event nums:" << resource_manager.get_cur_event_num();
}
CNodePtr AscendStreamAssign::GetLastInputCnode(const NotNull<KernelGraphPtr> &graph_ptr,
const CNodePtr &cur_cnode_ptr) {
vector<CNodePtr> AscendStreamAssign::GetLastInputCnode(const NotNull<KernelGraphPtr> &graph_ptr,
const CNodePtr &cur_cnode_ptr) {
auto cnode_ptr_list = graph_ptr->execution_order();
auto group_name = AnfAlgo::GetNodeAttr<std::string>(cur_cnode_ptr, kAttrGroup);
auto input_cnodes = GetInputKernels(cur_cnode_ptr);
if (input_cnodes.empty()) {
return nullptr;
return {};
}
auto it_pos = cnode_ptr_list.begin();
for (auto &cnode : input_cnodes) {
auto it = std::find(it_pos, cnode_ptr_list.end(), cnode);
if (it != cnode_ptr_list.end()) {
it_pos = it;
// record max index node for each stream
std::map<uint32_t, std::pair<CNodePtr, uint32_t>> result;
for (size_t i = 0; i < input_cnodes.size(); i++) {
auto &cur_input = input_cnodes.at(i);
auto stream_id = AnfAlgo::GetStreamId(cur_input);
auto cur_index = GetIndexByKey(graph_ptr, cur_input.get());
if (cur_index == UINT32_MAX) {
MS_LOG_EXCEPTION << "The input node:" << AnfAlgo::GetCNodeName(cur_input) << " is not found in graph";
}
auto it = result.find(stream_id);
if (it == result.end()) {
result[stream_id] = std::make_pair(cur_input, cur_index);
} else {
auto max_index = it->second.second;
if (cur_index > max_index) {
result[stream_id] = std::make_pair(cur_input, cur_index);
}
}
}
if (it_pos == cnode_ptr_list.begin() && *it_pos != input_cnodes.front()) {
MS_LOG(ERROR) << "The input of node:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "was not found";
vector<CNodePtr> final_inputs;
uint32_t max = 0;
CNodePtr max_common_cnode = nullptr;
for (const auto &item : result) {
if (IsHcom(item.second.first)) {
auto cur_group = AnfAlgo::GetNodeAttr<std::string>(item.second.first, kAttrGroup);
if (cur_group == group_name) {
continue;
} else {
final_inputs.emplace_back(item.second.first);
}
} else {
if (item.second.second > max) {
max_common_cnode = item.second.first;
}
}
}
return *it_pos;
if (max_common_cnode != nullptr) {
final_inputs.emplace_back(max_common_cnode);
}
return final_inputs;
}
vector<CNodePtr> AscendStreamAssign::GetInputKernels(const CNodePtr &node) {
@ -956,9 +988,7 @@ void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr>
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
auto cnode_ptr_list = graph_ptr->execution_order();
// key:group id, key: stream id, value:hcom index
std::map<std::string, std::map<uint32_t, vector<size_t>>> group_hcom_index;
std::map<std::string, uint32_t> group_first_hcom_stream;
std::map<std::string, uint32_t> group_last_hcom_stream;
std::map<std::string, std::vector<std::pair<uint32_t, vector<size_t>>>> group_hcom_index;
for (size_t i = 0; i < cnode_ptr_list.size(); i++) {
auto cur_cnode = cnode_ptr_list[i];
if (!IsHcom(cur_cnode)) {
@ -969,67 +999,60 @@ void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr>
MS_LOG_EXCEPTION << "hcom cnode " << cur_cnode->DebugString() << " has no group attr";
}
auto group_name = AnfAlgo::GetNodeAttr<std::string>(cur_cnode, kAttrGroup);
MS_LOG(INFO) << "Hcom node name:" << AnfAlgo::GetCNodeName(cur_cnode) << "; group:" << group_name
<< "; stream id:" << cur_stream_id;
auto iter = group_hcom_index.find(group_name);
if (iter == group_hcom_index.end()) {
std::map<uint32_t, vector<size_t>> hcom_index;
hcom_index[cur_stream_id] = {i};
std::vector<std::pair<uint32_t, vector<size_t>>> hcom_index;
hcom_index.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i}));
group_hcom_index[group_name] = hcom_index;
} else {
auto &hcom_index = iter->second;
auto it = hcom_index.find(cur_stream_id);
if (it == hcom_index.end()) {
hcom_index[cur_stream_id] = {i};
} else {
it->second.emplace_back(i);
bool exit = false;
for (auto &item : hcom_index) {
if (item.first == cur_stream_id) {
item.second.emplace_back(i);
exit = true;
break;
}
}
if (!exit) {
hcom_index.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i}));
}
}
}
// record first hcom stream id
auto it = group_first_hcom_stream.find(group_name);
if (it == group_first_hcom_stream.end()) {
group_first_hcom_stream[group_name] = cur_stream_id;
}
// record last hcom stream id
it = group_last_hcom_stream.find(group_name);
if (it != group_last_hcom_stream.end()) {
it->second = cur_stream_id;
} else {
group_last_hcom_stream[group_name] = cur_stream_id;
for (const auto &hcom_index : group_hcom_index) {
MS_LOG(DEBUG) << "Group:" << hcom_index.first;
for (const auto &item : hcom_index.second) {
MS_LOG(DEBUG) << "stream id:" << item.first;
for (const auto &index : item.second) {
MS_LOG(DEBUG) << "hcom index:" << index;
}
}
}
for (const auto &hcom_index : group_hcom_index) {
if (hcom_index.second.size() < 2) {
MS_LOG(INFO) << "Different stream hcom size is less than 2, no need insert event between them";
return;
MS_LOG(INFO) << "Group:" << hcom_index.first
<< "; different stream hcom size is less than 2, no need insert event between them";
continue;
}
auto group_name = hcom_index.first;
auto it = group_first_hcom_stream.find(group_name);
if (it == group_first_hcom_stream.end()) {
MS_LOG_EXCEPTION << "Can't find first hcom stream, hcom group id:" << group_name;
}
auto first_hcom_stream = it->second;
it = group_last_hcom_stream.find(group_name);
if (it == group_last_hcom_stream.end()) {
MS_LOG_EXCEPTION << "Can't find last hcom stream, hcom group id:" << group_name;
}
auto last_hcom_stream = it->second;
InsertEventBetweenHcom(graph_ptr, hcom_index.second, first_hcom_stream, last_hcom_stream);
InsertEventBetweenHcom(graph_ptr, hcom_index.second);
MS_LOG(INFO) << "After hcom depend hcom, total event nums:" << resource_manager.get_cur_event_num();
}
}
void AscendStreamAssign::InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &graph_ptr,
const map<uint32_t, vector<size_t>> &hcom_index,
uint32_t first_hcom_stream, uint32_t last_hcom_stream) {
const std::vector<std::pair<uint32_t, vector<size_t>>> &hcom_index) {
vector<CNodePtr> orders;
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
auto cnode_ptr_list = graph_ptr->execution_order();
uint32_t cur_event_id = resource_manager.ApplyNewEvent();
size_t first_stream_last_index = hcom_index.at(first_hcom_stream).back();
size_t last_stream_first_index = hcom_index.at(last_hcom_stream).front();
size_t first_stream_last_index = hcom_index[0].second.back();
size_t last_stream_first_index = hcom_index.back().second.front();
MS_LOG(INFO) << "First stream last index:" << first_stream_last_index
<< "; last stream first index:" << last_stream_first_index;
std::copy(cnode_ptr_list.begin(), cnode_ptr_list.begin() + first_stream_last_index, std::back_inserter(orders));
for (size_t i = first_stream_last_index; i <= last_stream_first_index; i++) {
auto cur_cnode = cnode_ptr_list[i];
@ -1049,7 +1072,17 @@ void AscendStreamAssign::InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &g
orders.emplace_back(recv);
orders.emplace_back(cur_cnode);
} else {
auto cur_stream_hcom_size = hcom_index.at(cur_hcom_stream_id).size();
size_t cur_stream_hcom_size = UINT32_MAX;
size_t first_index = UINT32_MAX;
size_t last_index = UINT32_MAX;
for (const auto &item : hcom_index) {
if (item.first == cur_hcom_stream_id) {
cur_stream_hcom_size = item.second.size();
first_index = item.second.front();
last_index = item.second.back();
}
}
if (cur_stream_hcom_size == 1) {
auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
orders.emplace_back(recv);
@ -1059,12 +1092,12 @@ void AscendStreamAssign::InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &g
orders.emplace_back(send);
} else {
// current stream, first hcom:add recv op
if (i == hcom_index.at(cur_hcom_stream_id).front()) {
if (i == first_index) {
auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
orders.emplace_back(recv);
cur_event_id = resource_manager.ApplyNewEvent();
orders.emplace_back(cur_cnode);
} else if (i == hcom_index.at(cur_hcom_stream_id).back()) {
} else if (i == last_index) {
// current stream, last hcom:add send op
orders.emplace_back(cur_cnode);
auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
@ -1080,19 +1113,19 @@ void AscendStreamAssign::InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &g
graph_ptr->set_execution_order(orders);
}
bool AscendStreamAssign::IsSatisfiedHcom(const std::map<uint32_t, vector<size_t>> &hcom_index, const CNodePtr &node_ptr,
size_t index) {
bool AscendStreamAssign::IsSatisfiedHcom(const std::vector<std::pair<uint32_t, vector<size_t>>> &hcom_index,
const CNodePtr &node_ptr, size_t index) {
MS_EXCEPTION_IF_NULL(node_ptr);
auto cur_hcom_stream_id = AnfAlgo::GetStreamId(node_ptr);
auto it = hcom_index.find(cur_hcom_stream_id);
if (it == hcom_index.end()) {
return false;
for (const auto &item : hcom_index) {
if (item.first == cur_hcom_stream_id) {
auto it = std::find(item.second.begin(), item.second.end(), index);
if (it != item.second.end()) {
return true;
}
}
}
auto iter = std::find(hcom_index.at(cur_hcom_stream_id).begin(), hcom_index.at(cur_hcom_stream_id).end(), index);
if (iter == hcom_index.at(cur_hcom_stream_id).end()) {
return false;
}
return true;
return false;
}
// section6
@ -1110,7 +1143,7 @@ void AscendStreamAssign::InsertEventForIndependentParallel(const NotNull<KernelG
CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it));
it = cnodes.insert(it + 1, send_cnode_ptr);
auto target = FindTargetOp(it, cnodes.end(), *(it - 1));
auto target = FindTargetOp(it, cnodes.end(), *(it - 1), false);
if (target == cnodes.end()) {
MS_LOG(DEBUG) << "Independ node[" << (*(it - 1))->fullname_with_scope()
<< "] can't find target for insert recv op, no insert send/recv";
@ -1441,7 +1474,8 @@ CNodePtr AscendStreamAssign::CreateRecvApplyKernel(const NotNull<KernelGraphPtr>
}
vector<CNodePtr>::iterator AscendStreamAssign::FindTargetOp(vector<CNodePtr>::iterator begin,
vector<CNodePtr>::iterator end, const CNodePtr &node) {
vector<CNodePtr>::iterator end, const CNodePtr &node,
bool exclude_hcom) {
while (begin != end) {
auto inputs = (*begin)->inputs();
for (size_t i = 1; i < inputs.size(); i++) {
@ -1451,16 +1485,22 @@ vector<CNodePtr>::iterator AscendStreamAssign::FindTargetOp(vector<CNodePtr>::it
auto new_inputs = cnode->inputs();
for (size_t j = 1; j < new_inputs.size(); j++) {
auto new_real_input = AnfAlgo::VisitKernel(new_inputs[j], 0);
// find target node except hcom op. insert event for hcom in:InsertEventHcomDependCommonBak function
// only insert one time
if (node == new_real_input.first) {
MS_LOG(DEBUG) << "Nop node find target op[" << (*begin)->DebugString() << "]";
return begin;
if (!(exclude_hcom && IsHcom(*begin))) {
MS_LOG(DEBUG) << "Nop node find target op[" << (*begin)->DebugString() << "]";
return begin;
}
}
}
} else {
auto real_input = AnfAlgo::VisitKernel(input, 0);
if (node == real_input.first) {
MS_LOG(DEBUG) << "Find target op[" << (*begin)->DebugString() << "]";
return begin;
if (!(exclude_hcom && IsHcom(*begin))) {
MS_LOG(DEBUG) << "Nop node find target op[" << (*begin)->DebugString() << "]";
return begin;
}
}
}
}

View File

@ -26,6 +26,7 @@
#include <vector>
#include <memory>
#include <unordered_set>
#include <utility>
#include "runtime/base.h"
#include "runtime/rt_model.h"
#include "runtime/stream.h"
@ -149,12 +150,13 @@ class AscendStreamAssign {
void InsertEventHcomDependCommon(const NotNull<KernelGraphPtr> &graph_ptr);
void InsertEventHcomDependCommonBak(const NotNull<KernelGraphPtr> &graph_ptr);
void InsertEventHcomDependHcom(const NotNull<KernelGraphPtr> &graph_ptr);
void InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &graph_ptr, const map<uint32_t, vector<size_t>> &hcom_index,
uint32_t first_hcom_stream, uint32_t last_hcom_stream);
void InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &graph_ptr,
const std::vector<std::pair<uint32_t, vector<size_t>>> &hcom_index);
void AdjustAtomicAddrCleanOrder(const NotNull<KernelGraphPtr> &graph_ptr);
CNodePtr GetLastInputCnode(const NotNull<KernelGraphPtr> &graph_ptr, const CNodePtr &cur_cnode_ptr);
bool IsSatisfiedHcom(const std::map<uint32_t, vector<size_t>> &hcom_index, const CNodePtr &node_ptr, size_t index);
vector<CNodePtr> GetLastInputCnode(const NotNull<KernelGraphPtr> &graph_ptr, const CNodePtr &cur_cnode_ptr);
bool IsSatisfiedHcom(const std::vector<std::pair<uint32_t, vector<size_t>>> &hcom_index, const CNodePtr &node_ptr,
size_t index);
void GetProcessedStream(const NotNull<KernelGraphPtr> &graph_ptr);
void GetNeedActiveStreams(const NotNull<KernelGraphPtr> &graph_ptr);
@ -169,7 +171,7 @@ class AscendStreamAssign {
bool IsIndependentNode(const CNodePtr &node_ptr);
bool IsProcessedStream(uint32_t stream_id);
vector<CNodePtr>::iterator FindTargetOp(vector<CNodePtr>::iterator begin, vector<CNodePtr>::iterator end,
const CNodePtr &node);
const CNodePtr &node, bool exclude_hcom);
void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t> *parallel_streams);
void SetLoopSink();
@ -199,6 +201,7 @@ class AscendStreamAssign {
std::vector<uint32_t> need_first_active_streams_{};
std::set<CNodeKey> independent_targets_;
// key:group name, value:key1:graph id, value1:stream id
std::map<std::string, std::map<uint32_t, std::set<uint32_t>>> group_hcom_graph_map_;
// key:graph id, value:stream set
std::map<uint32_t, std::set<uint32_t>> independent_graph_map_;