!18727 fix get allreduce fusion bug
Merge pull request !18727 from kisnwang/fix-hccl-allreduce-fusion-get-bug
This commit is contained in:
commit
a7ddd10af5
|
@ -149,9 +149,8 @@ void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector<uint32_t>
|
|||
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat)] = indices;
|
||||
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat16)] = indices;
|
||||
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat32)] = indices;
|
||||
} else {
|
||||
all_reduce_fusion_split_indices_[group] = indices;
|
||||
}
|
||||
all_reduce_fusion_split_indices_[group] = indices;
|
||||
}
|
||||
|
||||
std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitIndices(const std::string &group) const {
|
||||
|
@ -169,9 +168,8 @@ void ParallelContext::SetAllReduceFusionSplitSizes(const std::vector<uint32_t> &
|
|||
all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat)] = sizes;
|
||||
all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat16)] = sizes;
|
||||
all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat32)] = sizes;
|
||||
} else {
|
||||
all_reduce_fusion_split_sizes_[group] = sizes;
|
||||
}
|
||||
all_reduce_fusion_split_sizes_[group] = sizes;
|
||||
}
|
||||
|
||||
std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const std::string &group) const {
|
||||
|
|
Loading…
Reference in New Issue