!18016 add dtype for allreduce fusion

Merge pull request !18016 from kisnwang/add-dtype-for-hcclopfusion
This commit is contained in:
i-robot 2021-06-09 15:16:25 +08:00 committed by Gitee
commit 6aff341ca5
3 changed files with 30 additions and 13 deletions

View File

@ -87,7 +87,7 @@ std::string GetFusionGroupKey(const AnfNodePtr &node) {
if (attr_fusion == nullptr) { if (attr_fusion == nullptr) {
return ""; return "";
} }
int64_t fusion = GetValue<int64_t>(attr_fusion); auto fusion = GetValue<int64_t>(attr_fusion);
if (fusion == 0) { if (fusion == 0) {
return ""; return "";
} }
@ -101,7 +101,8 @@ std::string GetFusionGroupKey(const AnfNodePtr &node) {
if (attr_op != nullptr) { if (attr_op != nullptr) {
op = GetValue<std::string>(attr_op); op = GetValue<std::string>(attr_op);
} }
return group + op + std::to_string(fusion); auto dtype = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
return group + op + std::to_string(fusion) + TypeIdLabel(dtype);
} }
void CheckInputs(const std::vector<AnfNodePtr> &fusion_inputs) { void CheckInputs(const std::vector<AnfNodePtr> &fusion_inputs) {
@ -146,7 +147,7 @@ bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communic
} }
size_t segments = 0; size_t segments = 0;
if (split_indices.size() != 0) { if (!split_indices.empty()) {
uint32_t last_index = 0; uint32_t last_index = 0;
for (size_t i = 0; i < split_indices.size(); ++i) { for (size_t i = 0; i < split_indices.size(); ++i) {
uint32_t index = split_indices[i]; uint32_t index = split_indices[i];

View File

@ -142,11 +142,19 @@ void ParallelContext::set_optimizer_weight_shard_integrated_save(bool optimizer_
optimizer_weight_shard_integrated_save_ = optimizer_weight_shard_integrated_save; optimizer_weight_shard_integrated_save_ = optimizer_weight_shard_integrated_save;
} }
void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group) { void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector<uint32_t> &indices, const std::string &group) {
all_reduce_fusion_split_indices_[group] = indices; if (!group.empty() && group.find(TypeIdLabel(kNumberTypeFloat)) == std::string::npos &&
group.find(TypeIdLabel(kNumberTypeFloat16)) == std::string::npos &&
group.find(TypeIdLabel(kNumberTypeFloat32)) == std::string::npos) {
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat)] = indices;
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat16)] = indices;
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat32)] = indices;
} else {
all_reduce_fusion_split_indices_[group] = indices;
}
} }
const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitIndices(const std::string &group) const { std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitIndices(const std::string &group) const {
auto iter = all_reduce_fusion_split_indices_.find(group); auto iter = all_reduce_fusion_split_indices_.find(group);
if (iter != all_reduce_fusion_split_indices_.end()) { if (iter != all_reduce_fusion_split_indices_.end()) {
return iter->second; return iter->second;
@ -154,11 +162,19 @@ const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitIndices(cons
return {}; return {};
} }
void ParallelContext::SetAllReduceFusionSplitSizes(const std::vector<uint32_t> sizes, const std::string &group) { void ParallelContext::SetAllReduceFusionSplitSizes(const std::vector<uint32_t> &sizes, const std::string &group) {
all_reduce_fusion_split_sizes_[group] = sizes; if (!group.empty() && group.find(TypeIdLabel(kNumberTypeFloat)) == std::string::npos &&
group.find(TypeIdLabel(kNumberTypeFloat16)) == std::string::npos &&
group.find(TypeIdLabel(kNumberTypeFloat32)) == std::string::npos) {
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat)] = sizes;
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat16)] = sizes;
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat32)] = sizes;
} else {
all_reduce_fusion_split_indices_[group] = sizes;
}
} }
const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const std::string &group) const { std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const std::string &group) const {
auto iter = all_reduce_fusion_split_sizes_.find(group); auto iter = all_reduce_fusion_split_sizes_.find(group);
if (iter != all_reduce_fusion_split_sizes_.end()) { if (iter != all_reduce_fusion_split_sizes_.end()) {
return iter->second; return iter->second;

View File

@ -100,10 +100,10 @@ class ParallelContext {
void set_optimizer_weight_shard_integrated_save(bool optimizer_weight_shard_integrated_save); void set_optimizer_weight_shard_integrated_save(bool optimizer_weight_shard_integrated_save);
bool optimizer_weight_shard_integrated_save() const { return optimizer_weight_shard_integrated_save_; } bool optimizer_weight_shard_integrated_save() const { return optimizer_weight_shard_integrated_save_; }
void SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group); void SetAllReduceFusionSplitIndices(const std::vector<uint32_t> &indices, const std::string &group);
const std::vector<uint32_t> GetAllReduceFusionSplitIndices(const std::string &group) const; std::vector<uint32_t> GetAllReduceFusionSplitIndices(const std::string &group) const;
void SetAllReduceFusionSplitSizes(const std::vector<uint32_t> sizes, const std::string &group); void SetAllReduceFusionSplitSizes(const std::vector<uint32_t> &sizes, const std::string &group);
const std::vector<uint32_t> GetAllReduceFusionSplitSizes(const std::string &group) const; std::vector<uint32_t> GetAllReduceFusionSplitSizes(const std::string &group) const;
void set_enable_all_reduce_fusion(bool enable_all_reduce_fusion) { void set_enable_all_reduce_fusion(bool enable_all_reduce_fusion) {
enable_all_reduce_fusion_ = enable_all_reduce_fusion; enable_all_reduce_fusion_ = enable_all_reduce_fusion;
} }