Optimize execution order for only one HCCL group

This commit is contained in:
kswang 2021-06-03 21:43:15 +08:00
parent 6aff341ca5
commit 1d55f72bea
2 changed files with 15 additions and 30 deletions

View File

@ -135,23 +135,6 @@ std::string GetNodeGroup(const AnfNodePtr &node) {
}
return "";
}
// Decides whether the communication node `node` should have its execution order optimized.
//
// `optimized_comm_group` maps an op name (as returned by AnfAlgo::GetCNodeName) to the group that was
// recorded for the first node seen with that name; it must not be null and is updated in place.
//
// Returns false when the node's group contains kSyncBnGroup. Otherwise returns true if this is the first
// node seen with its op name (its group is recorded), or if its group equals the one already recorded for
// that op name; returns false when a different group was recorded earlier.
bool NeedOptimizeCommOp(const AnfNodePtr &node, std::map<std::string, std::string> *optimized_comm_group) {
  MS_EXCEPTION_IF_NULL(optimized_comm_group);

  const auto group = GetNodeGroup(node);
  const bool in_sync_bn_group = group.find(kSyncBnGroup) != string::npos;
  if (in_sync_bn_group) {
    return false;
  }

  // emplace leaves an existing entry untouched, so the first group recorded for an op name always wins.
  const auto op_name = AnfAlgo::GetCNodeName(node);
  const auto result = optimized_comm_group->emplace(op_name, group);
  const bool newly_recorded = result.second;
  return newly_recorded || result.first->second == group;
}
} // namespace
AnfNodePtr KernelGraph::MakeValueNode(const AnfNodePtr &node) const {
@ -188,7 +171,6 @@ void KernelGraph::EnqueueActiveNodes(const AnfNodePtr &node, std::queue<AnfNodeP
}
return;
}
// visit all reduce node first, then other nodes
std::vector<AnfNodePtr> active_nodes;
for (const auto &output_edge : it->second) {
@ -209,7 +191,9 @@ void KernelGraph::EnqueueActiveNodes(const AnfNodePtr &node, std::queue<AnfNodeP
if (node_input_num_[next_node] == 0 && visited_nodes->find(next_node) == visited_nodes->end()) {
(void)visited_nodes->insert(next_node);
bool is_comm_node = AnfAlgo::IsCommunicationOp(next_node);
if ((is_comm_node && comm_first) || (!is_comm_node && !comm_first)) {
if (AnfAlgo::CheckPrimitiveType(next_node, prim::kPrimLoad)) {
EnqueueActiveNodes(next_node, visit_queue, visited_nodes);
} else if ((is_comm_node && comm_first) || (!is_comm_node && !comm_first)) {
MS_LOG(DEBUG) << "Visit node:" << next_node->DebugString();
visit_queue->push(next_node);
} else {
@ -217,10 +201,7 @@ void KernelGraph::EnqueueActiveNodes(const AnfNodePtr &node, std::queue<AnfNodeP
}
}
}
for (auto &active_node : active_nodes) {
MS_EXCEPTION_IF_NULL(active_node);
MS_LOG(DEBUG) << "Visit node:" << active_node->DebugString();
visit_queue->push(active_node);
}
}
@ -233,7 +214,7 @@ void KernelGraph::SetExecOrderByDefault() {
std::queue<AnfNodePtr> zero_input_nodes;
std::stack<AnfNodePtr> delay_comm_stack;
std::queue<AnfNodePtr> communication_descendants;
std::map<std::string, std::string> optimized_comm_group;
std::string optimized_comm_group;
while (!seed_nodes.empty() || !delay_comm_stack.empty()) {
// seed nodes first, then delay comm nodes
if (seed_nodes.empty()) {
@ -262,9 +243,13 @@ void KernelGraph::SetExecOrderByDefault() {
}
// delay execute comm ops that need optimize
bool is_fused_comm = AnfAlgo::IsFusedCommunicationOp(node);
bool optimize_comm = is_fused_comm;
if (optimize_comm) {
optimize_comm = NeedOptimizeCommOp(node, &optimized_comm_group);
bool optimize_comm = false;
if (is_fused_comm && optimized_comm_group.empty()) {
auto node_group = GetNodeGroup(node);
if (node_group.find(kSyncBnGroup) == string::npos) {
optimized_comm_group = node_group;
optimize_comm = true;
}
}
if (optimize_comm) {
while (!delay_comm_stack.empty()) {

View File

@ -166,11 +166,11 @@ void ParallelContext::SetAllReduceFusionSplitSizes(const std::vector<uint32_t> &
if (!group.empty() && group.find(TypeIdLabel(kNumberTypeFloat)) == std::string::npos &&
group.find(TypeIdLabel(kNumberTypeFloat16)) == std::string::npos &&
group.find(TypeIdLabel(kNumberTypeFloat32)) == std::string::npos) {
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat)] = sizes;
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat16)] = sizes;
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat32)] = sizes;
all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat)] = sizes;
all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat16)] = sizes;
all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat32)] = sizes;
} else {
all_reduce_fusion_split_indices_[group] = sizes;
all_reduce_fusion_split_sizes_[group] = sizes;
}
}