forked from mindspore-Ecosystem/mindspore
optimize exec order only for one hccl group
This commit is contained in:
parent
6aff341ca5
commit
1d55f72bea
|
@ -135,23 +135,6 @@ std::string GetNodeGroup(const AnfNodePtr &node) {
|
|||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
bool NeedOptimizeCommOp(const AnfNodePtr &node, std::map<std::string, std::string> *optimized_comm_group) {
|
||||
MS_EXCEPTION_IF_NULL(optimized_comm_group);
|
||||
auto node_group = GetNodeGroup(node);
|
||||
if (node_group.find(kSyncBnGroup) != string::npos) {
|
||||
return false;
|
||||
}
|
||||
auto node_name = AnfAlgo::GetCNodeName(node);
|
||||
auto iter = optimized_comm_group->find(node_name);
|
||||
if (iter == optimized_comm_group->end()) {
|
||||
(*optimized_comm_group)[node_name] = node_group;
|
||||
return true;
|
||||
} else if (iter->second == node_group) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AnfNodePtr KernelGraph::MakeValueNode(const AnfNodePtr &node) const {
|
||||
|
@ -188,7 +171,6 @@ void KernelGraph::EnqueueActiveNodes(const AnfNodePtr &node, std::queue<AnfNodeP
|
|||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// visit all-reduce nodes first, then other nodes
|
||||
std::vector<AnfNodePtr> active_nodes;
|
||||
for (const auto &output_edge : it->second) {
|
||||
|
@ -209,7 +191,9 @@ void KernelGraph::EnqueueActiveNodes(const AnfNodePtr &node, std::queue<AnfNodeP
|
|||
if (node_input_num_[next_node] == 0 && visited_nodes->find(next_node) == visited_nodes->end()) {
|
||||
(void)visited_nodes->insert(next_node);
|
||||
bool is_comm_node = AnfAlgo::IsCommunicationOp(next_node);
|
||||
if ((is_comm_node && comm_first) || (!is_comm_node && !comm_first)) {
|
||||
if (AnfAlgo::CheckPrimitiveType(next_node, prim::kPrimLoad)) {
|
||||
EnqueueActiveNodes(next_node, visit_queue, visited_nodes);
|
||||
} else if ((is_comm_node && comm_first) || (!is_comm_node && !comm_first)) {
|
||||
MS_LOG(DEBUG) << "Visit node:" << next_node->DebugString();
|
||||
visit_queue->push(next_node);
|
||||
} else {
|
||||
|
@ -217,10 +201,7 @@ void KernelGraph::EnqueueActiveNodes(const AnfNodePtr &node, std::queue<AnfNodeP
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &active_node : active_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(active_node);
|
||||
MS_LOG(DEBUG) << "Visit node:" << active_node->DebugString();
|
||||
visit_queue->push(active_node);
|
||||
}
|
||||
}
|
||||
|
@ -233,7 +214,7 @@ void KernelGraph::SetExecOrderByDefault() {
|
|||
std::queue<AnfNodePtr> zero_input_nodes;
|
||||
std::stack<AnfNodePtr> delay_comm_stack;
|
||||
std::queue<AnfNodePtr> communication_descendants;
|
||||
std::map<std::string, std::string> optimized_comm_group;
|
||||
std::string optimized_comm_group;
|
||||
while (!seed_nodes.empty() || !delay_comm_stack.empty()) {
|
||||
// seed nodes first, then delay comm nodes
|
||||
if (seed_nodes.empty()) {
|
||||
|
@ -262,9 +243,13 @@ void KernelGraph::SetExecOrderByDefault() {
|
|||
}
|
||||
// delay execution of comm ops that need to be optimized
|
||||
bool is_fused_comm = AnfAlgo::IsFusedCommunicationOp(node);
|
||||
bool optimize_comm = is_fused_comm;
|
||||
if (optimize_comm) {
|
||||
optimize_comm = NeedOptimizeCommOp(node, &optimized_comm_group);
|
||||
bool optimize_comm = false;
|
||||
if (is_fused_comm && optimized_comm_group.empty()) {
|
||||
auto node_group = GetNodeGroup(node);
|
||||
if (node_group.find(kSyncBnGroup) == string::npos) {
|
||||
optimized_comm_group = node_group;
|
||||
optimize_comm = true;
|
||||
}
|
||||
}
|
||||
if (optimize_comm) {
|
||||
while (!delay_comm_stack.empty()) {
|
||||
|
|
|
@ -166,11 +166,11 @@ void ParallelContext::SetAllReduceFusionSplitSizes(const std::vector<uint32_t> &
|
|||
if (!group.empty() && group.find(TypeIdLabel(kNumberTypeFloat)) == std::string::npos &&
|
||||
group.find(TypeIdLabel(kNumberTypeFloat16)) == std::string::npos &&
|
||||
group.find(TypeIdLabel(kNumberTypeFloat32)) == std::string::npos) {
|
||||
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat)] = sizes;
|
||||
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat16)] = sizes;
|
||||
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat32)] = sizes;
|
||||
all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat)] = sizes;
|
||||
all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat16)] = sizes;
|
||||
all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat32)] = sizes;
|
||||
} else {
|
||||
all_reduce_fusion_split_indices_[group] = sizes;
|
||||
all_reduce_fusion_split_sizes_[group] = sizes;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue