set dtype for allreduce fusion

This commit is contained in:
kswang 2021-06-08 22:10:54 +08:00
parent 4932854776
commit 8aa0450b8d
3 changed files with 30 additions and 13 deletions

View File

@ -87,7 +87,7 @@ std::string GetFusionGroupKey(const AnfNodePtr &node) {
if (attr_fusion == nullptr) {
return "";
}
int64_t fusion = GetValue<int64_t>(attr_fusion);
auto fusion = GetValue<int64_t>(attr_fusion);
if (fusion == 0) {
return "";
}
@ -101,7 +101,8 @@ std::string GetFusionGroupKey(const AnfNodePtr &node) {
if (attr_op != nullptr) {
op = GetValue<std::string>(attr_op);
}
return group + op + std::to_string(fusion);
auto dtype = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
return group + op + std::to_string(fusion) + TypeIdLabel(dtype);
}
void CheckInputs(const std::vector<AnfNodePtr> &fusion_inputs) {
@ -146,7 +147,7 @@ bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communic
}
size_t segments = 0;
if (split_indices.size() != 0) {
if (!split_indices.empty()) {
uint32_t last_index = 0;
for (size_t i = 0; i < split_indices.size(); ++i) {
uint32_t index = split_indices[i];

View File

@ -142,11 +142,19 @@ void ParallelContext::set_optimizer_weight_shard_integrated_save(bool optimizer_
optimizer_weight_shard_integrated_save_ = optimizer_weight_shard_integrated_save;
}
void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group) {
// Records the all-reduce fusion split indices for a communication group.
//
// If `group` is empty, or already contains the label of kNumberTypeFloat,
// kNumberTypeFloat16 or kNumberTypeFloat32, the indices are stored under
// `group` exactly as given. Otherwise the same indices are stored three
// times, once under `group` suffixed with each of those three labels.
void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector<uint32_t> &indices, const std::string &group) {
  // Short-circuit order matches the checks below: emptiness first, then each label in turn.
  const bool use_group_as_key = group.empty() ||
                                group.find(TypeIdLabel(kNumberTypeFloat)) != std::string::npos ||
                                group.find(TypeIdLabel(kNumberTypeFloat16)) != std::string::npos ||
                                group.find(TypeIdLabel(kNumberTypeFloat32)) != std::string::npos;
  if (use_group_as_key) {
    all_reduce_fusion_split_indices_[group] = indices;
    return;
  }

  // No dtype label present: register the configuration under every labelled variant of the key.
  all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat)] = indices;
  all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat16)] = indices;
  all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat32)] = indices;
}
const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitIndices(const std::string &group) const {
std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitIndices(const std::string &group) const {
auto iter = all_reduce_fusion_split_indices_.find(group);
if (iter != all_reduce_fusion_split_indices_.end()) {
return iter->second;
@ -154,11 +162,19 @@ const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitIndices(cons
return {};
}
void ParallelContext::SetAllReduceFusionSplitSizes(const std::vector<uint32_t> sizes, const std::string &group) {
all_reduce_fusion_split_sizes_[group] = sizes;
void ParallelContext::SetAllReduceFusionSplitSizes(const std::vector<uint32_t> &sizes, const std::string &group) {
if (!group.empty() && group.find(TypeIdLabel(kNumberTypeFloat)) == std::string::npos &&
group.find(TypeIdLabel(kNumberTypeFloat16)) == std::string::npos &&
group.find(TypeIdLabel(kNumberTypeFloat32)) == std::string::npos) {
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat)] = sizes;
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat16)] = sizes;
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat32)] = sizes;
} else {
all_reduce_fusion_split_indices_[group] = sizes;
}
}
const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const std::string &group) const {
std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const std::string &group) const {
auto iter = all_reduce_fusion_split_sizes_.find(group);
if (iter != all_reduce_fusion_split_sizes_.end()) {
return iter->second;

View File

@ -100,10 +100,10 @@ class ParallelContext {
void set_optimizer_weight_shard_integrated_save(bool optimizer_weight_shard_integrated_save);
bool optimizer_weight_shard_integrated_save() const { return optimizer_weight_shard_integrated_save_; }
void SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group);
const std::vector<uint32_t> GetAllReduceFusionSplitIndices(const std::string &group) const;
void SetAllReduceFusionSplitSizes(const std::vector<uint32_t> sizes, const std::string &group);
const std::vector<uint32_t> GetAllReduceFusionSplitSizes(const std::string &group) const;
void SetAllReduceFusionSplitIndices(const std::vector<uint32_t> &indices, const std::string &group);
std::vector<uint32_t> GetAllReduceFusionSplitIndices(const std::string &group) const;
void SetAllReduceFusionSplitSizes(const std::vector<uint32_t> &sizes, const std::string &group);
std::vector<uint32_t> GetAllReduceFusionSplitSizes(const std::string &group) const;
void set_enable_all_reduce_fusion(bool enable_all_reduce_fusion) {
enable_all_reduce_fusion_ = enable_all_reduce_fusion;
}