!18727 fix get allreduce fusion bug

Merge pull request !18727 from kisnwang/fix-hccl-allreduce-fusion-get-bug
This commit is contained in:
i-robot 2021-06-23 02:12:08 +00:00 committed by Gitee
commit a7ddd10af5
1 changed files with 2 additions and 4 deletions

View File

@ -149,9 +149,8 @@ void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector<uint32_t>
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat)] = indices;
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat16)] = indices;
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat32)] = indices;
} else {
all_reduce_fusion_split_indices_[group] = indices;
}
all_reduce_fusion_split_indices_[group] = indices;
}
std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitIndices(const std::string &group) const {
@ -169,9 +168,8 @@ void ParallelContext::SetAllReduceFusionSplitSizes(const std::vector<uint32_t> &
all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat)] = sizes;
all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat16)] = sizes;
all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat32)] = sizes;
} else {
all_reduce_fusion_split_sizes_[group] = sizes;
}
all_reduce_fusion_split_sizes_[group] = sizes;
}
std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const std::string &group) const {