!9278 support hcom parallel by diff group

From: @gukecai
Reviewed-by: @kisnwang,@jjfeing
Signed-off-by: @jjfeing
This commit is contained in:
mindspore-ci-bot 2020-12-02 09:24:29 +08:00 committed by Gitee
commit cba50b13a5
2 changed files with 175 additions and 79 deletions

View File

@ -196,7 +196,7 @@ void AscendStreamAssign::AssignCommonStreamId(const CNodePtr &cur_cnode_ptr) {
void AscendStreamAssign::AssignHcom(const NotNull<KernelGraphPtr> &graph_ptr) {
auto cnode_ptr_list = graph_ptr->execution_order();
std::map<uint32_t, std::vector<CNodePtr>> graph_nodes_map;
std::map<std::string, std::map<uint32_t, std::vector<CNodePtr>>> group_graph_nodes_map;
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
// node has been assigned stream before
@ -205,27 +205,52 @@ void AscendStreamAssign::AssignHcom(const NotNull<KernelGraphPtr> &graph_ptr) {
}
if (IsHcom(cur_cnode_ptr)) {
if (!AnfAlgo::HasNodeAttr(kAttrGroup, cur_cnode_ptr)) {
MS_LOG_EXCEPTION << "hcom cnode " << cur_cnode_ptr->DebugString() << " has no group attr";
}
auto group_name = AnfAlgo::GetNodeAttr<std::string>(cur_cnode_ptr, kAttrGroup);
auto hcom_graph_id = AnfAlgo::GetGraphId(cur_cnode_ptr.get());
auto it = graph_nodes_map.find(hcom_graph_id);
if (it == graph_nodes_map.end()) {
auto iter = group_graph_nodes_map.find(group_name);
if (iter == group_graph_nodes_map.end()) {
std::map<uint32_t, std::vector<CNodePtr>> graph_nodes_map;
graph_nodes_map[hcom_graph_id] = {cur_cnode_ptr};
group_graph_nodes_map[group_name] = graph_nodes_map;
} else {
it->second.emplace_back(cur_cnode_ptr);
auto &graph_nodes_map = iter->second;
auto it = graph_nodes_map.find(hcom_graph_id);
if (it == graph_nodes_map.end()) {
graph_nodes_map[hcom_graph_id] = {cur_cnode_ptr};
} else {
it->second.emplace_back(cur_cnode_ptr);
}
}
}
}
MS_LOG(INFO) << "hcom diff graph id size:" << graph_nodes_map.size();
for (const auto &item : graph_nodes_map) {
bool new_graph = true;
auto graph_id = item.first;
hcom_graph_map_[graph_id] = {};
for (const auto &hcom_node_ptr : item.second) {
auto assigned_stream_id = AssignHcomStreamId(hcom_node_ptr, new_graph);
hcom_graph_map_[graph_id].emplace(assigned_stream_id);
new_graph = false;
}
MS_LOG(INFO) << "hcom diff group size:" << group_graph_nodes_map.size();
for (const auto &item : group_graph_nodes_map) {
MS_LOG_INFO << "group id:" << item.first << "; diff graph id size:" << item.second.size();
}
for (const auto &diff_group : group_graph_nodes_map) {
// group id:
std::map<uint32_t, std::set<uint32_t>> hcom_graph_map;
for (const auto &item : diff_group.second) {
bool new_graph = true;
auto graph_id = item.first;
hcom_graph_map[graph_id] = {};
for (const auto &hcom_node_ptr : item.second) {
auto assigned_stream_id = AssignHcomStreamId(hcom_node_ptr, new_graph);
hcom_graph_map[graph_id].emplace(assigned_stream_id);
new_graph = false;
}
}
group_hcom_graph_map_[diff_group.first] = hcom_graph_map;
}
for (const auto &item : group_hcom_graph_map_) {
MS_LOG_INFO << "group id:" << item.first << "; hcom stream nums:" << item.second.size();
}
MS_LOG(INFO) << "hcom stream nums : " << hcom_stream_map_.size();
}
uint32_t AscendStreamAssign::AssignHcomStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph) {
@ -337,7 +362,7 @@ void AscendStreamAssign::InsertStreamActive(const NotNull<KernelGraphPtr> &graph
}
void AscendStreamAssign::InsertStreamActiveForParallel(const NotNull<KernelGraphPtr> &graph_ptr) {
if (hcom_graph_map_.empty() && independent_graph_map_.empty()) {
if (group_hcom_graph_map_.empty() && independent_graph_map_.empty()) {
MS_LOG(INFO) << "Hcom and independent is empty";
return;
}
@ -347,19 +372,32 @@ void AscendStreamAssign::InsertStreamActiveForParallel(const NotNull<KernelGraph
return;
}
MS_LOG(DEBUG) << "Hcom grpah map size:" << hcom_graph_map_.size();
std::map<uint32_t, std::set<uint32_t>> other_graph;
for (const auto &item : hcom_graph_map_) {
MS_LOG(INFO) << "Graph id:" << item.first;
if (item.first == root_graph_id) {
if (loop_sink_) {
ActiveRootGraphHcom(graph_ptr, item.second);
std::set<uint32_t> hcom_streams;
for (const auto &graph_nodes : group_hcom_graph_map_) {
for (const auto &item : graph_nodes.second) {
MS_LOG(INFO) << "Graph id:" << item.first;
if (item.first == root_graph_id) {
if (loop_sink_) {
hcom_streams.insert(item.second.begin(), item.second.end());
}
} else {
auto it = other_graph.find(item.first);
if (it == other_graph.end()) {
other_graph[item.first] = item.second;
} else {
for (const auto &stream : item.second) {
it->second.emplace(stream);
}
}
}
} else {
other_graph[item.first] = item.second;
}
}
if (!hcom_streams.empty()) {
ActiveRootGraphHcom(graph_ptr, hcom_streams);
}
MS_LOG(INFO) << "Independent graph map size:" << independent_graph_map_.size();
for (const auto &item : independent_graph_map_) {
MS_LOG(DEBUG) << "Graph id:" << item.first;
@ -505,7 +543,6 @@ void AscendStreamAssign::ActiveRootGraphIndependent(const NotNull<KernelGraphPtr
independent_stream_activated_ = true;
graph_ptr->set_execution_order(update_cnode_list);
}
void AscendStreamAssign::InsertStreamActiveForCommon(const NotNull<KernelGraphPtr> &graph_ptr) {
MS_LOG(INFO) << "Start";
GetProcessedStream(graph_ptr);
@ -733,7 +770,7 @@ bool AscendStreamAssign::IsProcessedStream(uint32_t stream_id) {
void AscendStreamAssign::InsertEventForHcomParallel(const NotNull<KernelGraphPtr> &graph_ptr) {
MS_LOG(INFO) << "Start";
InsertEventCommonDependHcom(graph_ptr);
InsertEventHcomDependCommon(graph_ptr);
InsertEventHcomDependCommonBak(graph_ptr);
InsertEventHcomDependHcom(graph_ptr);
MS_LOG(INFO) << "End";
}
@ -777,36 +814,6 @@ void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull<KernelGraphPt
MS_LOG(INFO) << "After common depend hcom, total event nums:" << resource_manager.get_cur_event_num();
}
CNodePtr AscendStreamAssign::GetLastInputCnode(const NotNull<KernelGraphPtr> &graph_ptr,
const CNodePtr &cur_cnode_ptr) {
auto cnode_ptr_list = graph_ptr->execution_order();
auto &inputs = cur_cnode_ptr->inputs();
auto it_pos = cnode_ptr_list.begin();
for (size_t i = 1; i < inputs.size(); i++) {
if (inputs[i]->isa<CNode>()) {
auto cnode = inputs[i]->cast<CNodePtr>();
while (opt::IsNopNode(cnode)) {
cnode = cnode->inputs()[1]->cast<CNodePtr>();
}
auto it = std::find(it_pos, cnode_ptr_list.end(), cnode);
if (it != cnode_ptr_list.end()) {
it_pos = it;
}
} else {
continue;
}
}
if (it_pos == cnode_ptr_list.begin() && *it_pos != inputs[1]) {
MS_LOG(EXCEPTION) << "The input of node:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "was not found";
}
MS_LOG(INFO) << "The las input of node:" << cur_cnode_ptr->DebugString() << " is:" << (*it_pos)->fullname_with_scope()
<< "; name:" << (*it_pos)->DebugString();
return *it_pos;
}
// after memory reuse is correct, use this function
void AscendStreamAssign::InsertEventHcomDependCommonBak(const NotNull<KernelGraphPtr> &graph_ptr) {
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
@ -830,7 +837,7 @@ void AscendStreamAssign::InsertEventHcomDependCommonBak(const NotNull<KernelGrap
auto last_input_cnode = GetLastInputCnode(graph_ptr, cur_cnode_ptr);
auto it = std::find(cnodes.begin(), cnodes.end(), last_input_cnode);
if (it == cnodes.end()) {
MS_LOG(ERROR) << "hcom:" << AnfAlgo::GetCNodeName(cur_cnode_ptr)
MS_LOG(ERROR) << "Hcom:" << AnfAlgo::GetCNodeName(cur_cnode_ptr)
<< "get last input:" << AnfAlgo::GetCNodeName(last_input_cnode) << "; but last input not in cnodes";
} else {
uint32_t cur_event_id = resource_manager.ApplyNewEvent();
@ -848,6 +855,58 @@ void AscendStreamAssign::InsertEventHcomDependCommonBak(const NotNull<KernelGrap
MS_LOG(INFO) << "After hcom depend common, total event nums:" << resource_manager.get_cur_event_num();
}
CNodePtr AscendStreamAssign::GetLastInputCnode(const NotNull<KernelGraphPtr> &graph_ptr,
const CNodePtr &cur_cnode_ptr) {
auto cnode_ptr_list = graph_ptr->execution_order();
auto input_cnodes = GetInputKernels(cur_cnode_ptr);
if (input_cnodes.empty()) {
return nullptr;
}
auto it_pos = cnode_ptr_list.begin();
for (auto &cnode : input_cnodes) {
auto it = std::find(it_pos, cnode_ptr_list.end(), cnode);
if (it != cnode_ptr_list.end()) {
it_pos = it;
}
}
if (it_pos == cnode_ptr_list.begin() && *it_pos != input_cnodes.front()) {
MS_LOG(ERROR) << "The input of node:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "was not found";
}
return *it_pos;
}
vector<CNodePtr> AscendStreamAssign::GetInputKernels(const CNodePtr &node) {
vector<CNodePtr> input_cnodes;
queue<CNodePtr> nop_nodes;
auto inputs = node->inputs();
for (size_t i = 1; i < inputs.size(); i++) {
auto real_input = AnfAlgo::VisitKernel(inputs[i], 0);
auto node = real_input.first;
if (opt::IsNopNode(node)) {
nop_nodes.push(node->cast<CNodePtr>());
while (!nop_nodes.empty()) {
auto cur_node = nop_nodes.front();
nop_nodes.pop();
auto new_inputs = cur_node->inputs();
for (size_t j = 1; j < new_inputs.size(); j++) {
auto new_real_input = AnfAlgo::VisitKernel(new_inputs[j], 0);
auto new_node = new_real_input.first;
if (opt::IsNopNode(new_node)) {
nop_nodes.push(new_node->cast<CNodePtr>());
} else if (new_node->isa<CNode>()) {
input_cnodes.emplace_back(new_node->cast<CNodePtr>());
}
}
}
} else if (node->isa<CNode>()) {
input_cnodes.emplace_back(node->cast<CNodePtr>());
}
}
return input_cnodes;
}
void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull<KernelGraphPtr> &graph_ptr) {
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
auto cnode_ptr_list = graph_ptr->execution_order();
@ -896,40 +955,70 @@ void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull<KernelGraphPt
void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr> &graph_ptr) {
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
auto cnode_ptr_list = graph_ptr->execution_order();
uint32_t first_hcom_stream = kInvalidStreamId;
uint32_t last_hcom_stream = kInvalidStreamId;
// key: stream id, value:hcom index
std::map<uint32_t, vector<size_t>> hcom_index;
// key:group id, key: stream id, value:hcom index
std::map<std::string, std::map<uint32_t, vector<size_t>>> group_hcom_index;
std::map<std::string, uint32_t> group_first_hcom_stream;
std::map<std::string, uint32_t> group_last_hcom_stream;
for (size_t i = 0; i < cnode_ptr_list.size(); i++) {
auto cur_cnode = cnode_ptr_list[i];
if (!IsHcom(cur_cnode)) {
continue;
}
uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode);
auto it = hcom_index.find(cur_stream_id);
if (it != hcom_index.end()) {
hcom_index[cur_stream_id].emplace_back(i);
} else {
if (!AnfAlgo::HasNodeAttr(kAttrGroup, cur_cnode)) {
MS_LOG_EXCEPTION << "hcom cnode " << cur_cnode->DebugString() << " has no group attr";
}
auto group_name = AnfAlgo::GetNodeAttr<std::string>(cur_cnode, kAttrGroup);
auto iter = group_hcom_index.find(group_name);
if (iter == group_hcom_index.end()) {
std::map<uint32_t, vector<size_t>> hcom_index;
hcom_index[cur_stream_id] = {i};
group_hcom_index[group_name] = hcom_index;
} else {
auto &hcom_index = iter->second;
auto it = hcom_index.find(cur_stream_id);
if (it == hcom_index.end()) {
hcom_index[cur_stream_id] = {i};
} else {
it->second.emplace_back(i);
}
}
// record first hcom stream id
if (first_hcom_stream == kInvalidStreamId) {
first_hcom_stream = cur_stream_id;
auto it = group_first_hcom_stream.find(group_name);
if (it == group_first_hcom_stream.end()) {
group_first_hcom_stream[group_name] = cur_stream_id;
}
// record last hcom stream id
if (cur_stream_id != last_hcom_stream) {
last_hcom_stream = cur_stream_id;
it = group_last_hcom_stream.find(group_name);
if (it != group_last_hcom_stream.end()) {
it->second = cur_stream_id;
} else {
group_last_hcom_stream[group_name] = cur_stream_id;
}
}
if (hcom_index.size() < 2) {
MS_LOG(INFO) << "Different stream hcom size is less than 2, no need insert event between them";
return;
for (const auto &hcom_index : group_hcom_index) {
if (hcom_index.second.size() < 2) {
MS_LOG(INFO) << "Different stream hcom size is less than 2, no need insert event between them";
return;
}
auto group_name = hcom_index.first;
auto it = group_first_hcom_stream.find(group_name);
if (it == group_first_hcom_stream.end()) {
MS_LOG_EXCEPTION << "Can't find first hcom stream, hcom group id:" << group_name;
}
auto first_hcom_stream = it->second;
it = group_last_hcom_stream.find(group_name);
if (it == group_last_hcom_stream.end()) {
MS_LOG_EXCEPTION << "Can't find last hcom stream, hcom group id:" << group_name;
}
auto last_hcom_stream = it->second;
InsertEventBetweenHcom(graph_ptr, hcom_index.second, first_hcom_stream, last_hcom_stream);
MS_LOG(INFO) << "After hcom depend hcom, total event nums:" << resource_manager.get_cur_event_num();
}
InsertEventBetweenHcom(graph_ptr, hcom_index, first_hcom_stream, last_hcom_stream);
MS_LOG(INFO) << "After hcom depend hcom, total event nums:" << resource_manager.get_cur_event_num();
}
void AscendStreamAssign::InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &graph_ptr,
@ -1199,9 +1288,12 @@ void AscendStreamAssign::GetNeedActiveStreams(const NotNull<KernelGraphPtr> &gra
// 3)hcom stream:if has not been activate, push to need active vector
if (!hcom_stream_activated_) {
auto it = hcom_graph_map_.find(root_graph_id);
if (it != hcom_graph_map_.end()) {
std::copy(it->second.begin(), it->second.end(), std::back_inserter(need_first_active_streams_));
for (const auto &item : group_hcom_graph_map_) {
auto &hcom_graph_map = item.second;
auto it = hcom_graph_map.find(root_graph_id);
if (it != hcom_graph_map.end()) {
std::copy(it->second.begin(), it->second.end(), std::back_inserter(need_first_active_streams_));
}
}
}
@ -1434,7 +1526,7 @@ void AscendStreamAssign::Reset() {
event_map_.clear();
independent_targets_.clear();
independent_graph_map_.clear();
hcom_graph_map_.clear();
group_hcom_graph_map_.clear();
middle_active_streams_.clear();
}

View File

@ -22,6 +22,7 @@
#include <map>
#include <set>
#include <string>
#include <queue>
#include <vector>
#include <memory>
#include <unordered_set>
@ -35,6 +36,7 @@ namespace mindspore {
namespace device {
namespace ascend {
using std::map;
using std::queue;
using std::shared_ptr;
using std::unordered_map;
using std::unordered_set;
@ -184,6 +186,7 @@ class AscendStreamAssign {
void PrintStreamGroups();
void FindEventRelations(const NotNull<KernelGraphPtr> &graph_ptr);
bool IsSatisfiedEvent(uint32_t send_stream_id, uint32_t recv_stream_id) const;
vector<CNodePtr> GetInputKernels(const CNodePtr &node);
bool independent_stream_activated_{false};
bool hcom_stream_activated_{false};
@ -195,8 +198,9 @@ class AscendStreamAssign {
std::set<uint32_t> processed_streams_{};
std::vector<uint32_t> need_first_active_streams_{};
std::set<CNodeKey> independent_targets_;
std::map<std::string, std::map<uint32_t, std::set<uint32_t>>> group_hcom_graph_map_;
// key:graph id, value:stream set
std::map<uint32_t, std::set<uint32_t>> hcom_graph_map_;
std::map<uint32_t, std::set<uint32_t>> independent_graph_map_;
// attr for memory copy reuse