diff --git a/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc b/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc
index 0506a997442..64c41bda516 100644
--- a/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc
+++ b/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc
@@ -87,7 +87,7 @@ std::string GetFusionGroupKey(const AnfNodePtr &node) {
   if (attr_fusion == nullptr) {
     return "";
   }
-  int64_t fusion = GetValue(attr_fusion);
+  auto fusion = GetValue(attr_fusion);
   if (fusion == 0) {
     return "";
   }
@@ -101,7 +101,8 @@ std::string GetFusionGroupKey(const AnfNodePtr &node) {
   if (attr_op != nullptr) {
     op = GetValue(attr_op);
   }
-  return group + op + std::to_string(fusion);
+  auto dtype = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
+  return group + op + std::to_string(fusion) + TypeIdLabel(dtype);
 }
 
 void CheckInputs(const std::vector &fusion_inputs) {
@@ -146,7 +147,7 @@ bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communic
   }
 
   size_t segments = 0;
-  if (split_indices.size() != 0) {
+  if (!split_indices.empty()) {
     uint32_t last_index = 0;
     for (size_t i = 0; i < split_indices.size(); ++i) {
       uint32_t index = split_indices[i];
diff --git a/mindspore/ccsrc/frontend/parallel/context.cc b/mindspore/ccsrc/frontend/parallel/context.cc
index b82701c2112..be7b07063db 100644
--- a/mindspore/ccsrc/frontend/parallel/context.cc
+++ b/mindspore/ccsrc/frontend/parallel/context.cc
@@ -142,11 +142,19 @@ void ParallelContext::set_optimizer_weight_shard_integrated_save(bool optimizer_
   optimizer_weight_shard_integrated_save_ = optimizer_weight_shard_integrated_save;
 }
 
-void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector indices, const std::string &group) {
-  all_reduce_fusion_split_indices_[group] = indices;
+void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector &indices, const std::string &group) {
+  if (!group.empty() && group.find(TypeIdLabel(kNumberTypeFloat)) == std::string::npos &&
+      group.find(TypeIdLabel(kNumberTypeFloat16)) == std::string::npos &&
+      group.find(TypeIdLabel(kNumberTypeFloat32)) == std::string::npos) {
+    all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat)] = indices;
+    all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat16)] = indices;
+    all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat32)] = indices;
+  } else {
+    all_reduce_fusion_split_indices_[group] = indices;
+  }
 }
 
-const std::vector ParallelContext::GetAllReduceFusionSplitIndices(const std::string &group) const {
+std::vector ParallelContext::GetAllReduceFusionSplitIndices(const std::string &group) const {
   auto iter = all_reduce_fusion_split_indices_.find(group);
   if (iter != all_reduce_fusion_split_indices_.end()) {
     return iter->second;
@@ -154,11 +162,22 @@ const std::vector ParallelContext::GetAllReduceFusionSplitIndices(cons
   return {};
 }
 
-void ParallelContext::SetAllReduceFusionSplitSizes(const std::vector sizes, const std::string &group) {
-  all_reduce_fusion_split_sizes_[group] = sizes;
+// Stores `sizes` in all_reduce_fusion_split_sizes_, the map that GetAllReduceFusionSplitSizes looks up.
+// A non-empty `group` containing none of the TypeIdLabel strings for kNumberTypeFloat, kNumberTypeFloat16 or
+// kNumberTypeFloat32 gets one entry per label appended to `group`; any other `group` is used as the key unchanged.
+void ParallelContext::SetAllReduceFusionSplitSizes(const std::vector &sizes, const std::string &group) {
+  if (!group.empty() && group.find(TypeIdLabel(kNumberTypeFloat)) == std::string::npos &&
+      group.find(TypeIdLabel(kNumberTypeFloat16)) == std::string::npos &&
+      group.find(TypeIdLabel(kNumberTypeFloat32)) == std::string::npos) {
+    all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat)] = sizes;
+    all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat16)] = sizes;
+    all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat32)] = sizes;
+  } else {
+    all_reduce_fusion_split_sizes_[group] = sizes;
+  }
 }
 
-const std::vector ParallelContext::GetAllReduceFusionSplitSizes(const std::string &group) const {
+std::vector ParallelContext::GetAllReduceFusionSplitSizes(const std::string &group) const {
   auto iter = all_reduce_fusion_split_sizes_.find(group);
   if (iter != all_reduce_fusion_split_sizes_.end()) {
     return iter->second;
diff --git a/mindspore/ccsrc/frontend/parallel/context.h b/mindspore/ccsrc/frontend/parallel/context.h
index e1ce78e66d5..3b339229e52 100644
--- a/mindspore/ccsrc/frontend/parallel/context.h
+++ b/mindspore/ccsrc/frontend/parallel/context.h
@@ -100,10 +100,10 @@ class ParallelContext {
   void set_optimizer_weight_shard_integrated_save(bool optimizer_weight_shard_integrated_save);
   bool optimizer_weight_shard_integrated_save() const { return optimizer_weight_shard_integrated_save_; }
 
-  void SetAllReduceFusionSplitIndices(const std::vector indices, const std::string &group);
-  const std::vector GetAllReduceFusionSplitIndices(const std::string &group) const;
-  void SetAllReduceFusionSplitSizes(const std::vector sizes, const std::string &group);
-  const std::vector GetAllReduceFusionSplitSizes(const std::string &group) const;
+  void SetAllReduceFusionSplitIndices(const std::vector &indices, const std::string &group);
+  std::vector GetAllReduceFusionSplitIndices(const std::string &group) const;
+  void SetAllReduceFusionSplitSizes(const std::vector &sizes, const std::string &group);
+  std::vector GetAllReduceFusionSplitSizes(const std::string &group) const;
   void set_enable_all_reduce_fusion(bool enable_all_reduce_fusion) {
     enable_all_reduce_fusion_ = enable_all_reduce_fusion;
   }