forked from mindspore-Ecosystem/mindspore
set dtype for allreduce fusion
This commit is contained in:
parent
4932854776
commit
8aa0450b8d
|
@ -87,7 +87,7 @@ std::string GetFusionGroupKey(const AnfNodePtr &node) {
|
|||
if (attr_fusion == nullptr) {
|
||||
return "";
|
||||
}
|
||||
int64_t fusion = GetValue<int64_t>(attr_fusion);
|
||||
auto fusion = GetValue<int64_t>(attr_fusion);
|
||||
if (fusion == 0) {
|
||||
return "";
|
||||
}
|
||||
|
@ -101,7 +101,8 @@ std::string GetFusionGroupKey(const AnfNodePtr &node) {
|
|||
if (attr_op != nullptr) {
|
||||
op = GetValue<std::string>(attr_op);
|
||||
}
|
||||
return group + op + std::to_string(fusion);
|
||||
auto dtype = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
|
||||
return group + op + std::to_string(fusion) + TypeIdLabel(dtype);
|
||||
}
|
||||
|
||||
void CheckInputs(const std::vector<AnfNodePtr> &fusion_inputs) {
|
||||
|
@ -146,7 +147,7 @@ bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communic
|
|||
}
|
||||
|
||||
size_t segments = 0;
|
||||
if (split_indices.size() != 0) {
|
||||
if (!split_indices.empty()) {
|
||||
uint32_t last_index = 0;
|
||||
for (size_t i = 0; i < split_indices.size(); ++i) {
|
||||
uint32_t index = split_indices[i];
|
||||
|
|
|
@ -142,11 +142,19 @@ void ParallelContext::set_optimizer_weight_shard_integrated_save(bool optimizer_
|
|||
optimizer_weight_shard_integrated_save_ = optimizer_weight_shard_integrated_save;
|
||||
}
|
||||
|
||||
void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group) {
|
||||
void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector<uint32_t> &indices, const std::string &group) {
|
||||
if (!group.empty() && group.find(TypeIdLabel(kNumberTypeFloat)) == std::string::npos &&
|
||||
group.find(TypeIdLabel(kNumberTypeFloat16)) == std::string::npos &&
|
||||
group.find(TypeIdLabel(kNumberTypeFloat32)) == std::string::npos) {
|
||||
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat)] = indices;
|
||||
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat16)] = indices;
|
||||
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat32)] = indices;
|
||||
} else {
|
||||
all_reduce_fusion_split_indices_[group] = indices;
|
||||
}
|
||||
}
|
||||
|
||||
const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitIndices(const std::string &group) const {
|
||||
std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitIndices(const std::string &group) const {
|
||||
auto iter = all_reduce_fusion_split_indices_.find(group);
|
||||
if (iter != all_reduce_fusion_split_indices_.end()) {
|
||||
return iter->second;
|
||||
|
@ -154,11 +162,19 @@ const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitIndices(cons
|
|||
return {};
|
||||
}
|
||||
|
||||
void ParallelContext::SetAllReduceFusionSplitSizes(const std::vector<uint32_t> sizes, const std::string &group) {
|
||||
all_reduce_fusion_split_sizes_[group] = sizes;
|
||||
void ParallelContext::SetAllReduceFusionSplitSizes(const std::vector<uint32_t> &sizes, const std::string &group) {
|
||||
if (!group.empty() && group.find(TypeIdLabel(kNumberTypeFloat)) == std::string::npos &&
|
||||
group.find(TypeIdLabel(kNumberTypeFloat16)) == std::string::npos &&
|
||||
group.find(TypeIdLabel(kNumberTypeFloat32)) == std::string::npos) {
|
||||
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat)] = sizes;
|
||||
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat16)] = sizes;
|
||||
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat32)] = sizes;
|
||||
} else {
|
||||
all_reduce_fusion_split_indices_[group] = sizes;
|
||||
}
|
||||
}
|
||||
|
||||
const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const std::string &group) const {
|
||||
std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const std::string &group) const {
|
||||
auto iter = all_reduce_fusion_split_sizes_.find(group);
|
||||
if (iter != all_reduce_fusion_split_sizes_.end()) {
|
||||
return iter->second;
|
||||
|
|
|
@ -100,10 +100,10 @@ class ParallelContext {
|
|||
void set_optimizer_weight_shard_integrated_save(bool optimizer_weight_shard_integrated_save);
|
||||
bool optimizer_weight_shard_integrated_save() const { return optimizer_weight_shard_integrated_save_; }
|
||||
|
||||
void SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group);
|
||||
const std::vector<uint32_t> GetAllReduceFusionSplitIndices(const std::string &group) const;
|
||||
void SetAllReduceFusionSplitSizes(const std::vector<uint32_t> sizes, const std::string &group);
|
||||
const std::vector<uint32_t> GetAllReduceFusionSplitSizes(const std::string &group) const;
|
||||
void SetAllReduceFusionSplitIndices(const std::vector<uint32_t> &indices, const std::string &group);
|
||||
std::vector<uint32_t> GetAllReduceFusionSplitIndices(const std::string &group) const;
|
||||
void SetAllReduceFusionSplitSizes(const std::vector<uint32_t> &sizes, const std::string &group);
|
||||
std::vector<uint32_t> GetAllReduceFusionSplitSizes(const std::string &group) const;
|
||||
void set_enable_all_reduce_fusion(bool enable_all_reduce_fusion) {
|
||||
enable_all_reduce_fusion_ = enable_all_reduce_fusion;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue