forked from mindspore-Ecosystem/mindspore
!18016 add dtype for allreduce fusion
Merge pull request !18016 from kisnwang/add-dtype-for-hcclopfusion
This commit is contained in:
commit
6aff341ca5
|
@ -87,7 +87,7 @@ std::string GetFusionGroupKey(const AnfNodePtr &node) {
|
||||||
if (attr_fusion == nullptr) {
|
if (attr_fusion == nullptr) {
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
int64_t fusion = GetValue<int64_t>(attr_fusion);
|
auto fusion = GetValue<int64_t>(attr_fusion);
|
||||||
if (fusion == 0) {
|
if (fusion == 0) {
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
@ -101,7 +101,8 @@ std::string GetFusionGroupKey(const AnfNodePtr &node) {
|
||||||
if (attr_op != nullptr) {
|
if (attr_op != nullptr) {
|
||||||
op = GetValue<std::string>(attr_op);
|
op = GetValue<std::string>(attr_op);
|
||||||
}
|
}
|
||||||
return group + op + std::to_string(fusion);
|
auto dtype = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
|
||||||
|
return group + op + std::to_string(fusion) + TypeIdLabel(dtype);
|
||||||
}
|
}
|
||||||
|
|
||||||
void CheckInputs(const std::vector<AnfNodePtr> &fusion_inputs) {
|
void CheckInputs(const std::vector<AnfNodePtr> &fusion_inputs) {
|
||||||
|
@ -146,7 +147,7 @@ bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communic
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t segments = 0;
|
size_t segments = 0;
|
||||||
if (split_indices.size() != 0) {
|
if (!split_indices.empty()) {
|
||||||
uint32_t last_index = 0;
|
uint32_t last_index = 0;
|
||||||
for (size_t i = 0; i < split_indices.size(); ++i) {
|
for (size_t i = 0; i < split_indices.size(); ++i) {
|
||||||
uint32_t index = split_indices[i];
|
uint32_t index = split_indices[i];
|
||||||
|
|
|
@ -142,11 +142,19 @@ void ParallelContext::set_optimizer_weight_shard_integrated_save(bool optimizer_
|
||||||
optimizer_weight_shard_integrated_save_ = optimizer_weight_shard_integrated_save;
|
optimizer_weight_shard_integrated_save_ = optimizer_weight_shard_integrated_save;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group) {
|
void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector<uint32_t> &indices, const std::string &group) {
|
||||||
all_reduce_fusion_split_indices_[group] = indices;
|
if (!group.empty() && group.find(TypeIdLabel(kNumberTypeFloat)) == std::string::npos &&
|
||||||
|
group.find(TypeIdLabel(kNumberTypeFloat16)) == std::string::npos &&
|
||||||
|
group.find(TypeIdLabel(kNumberTypeFloat32)) == std::string::npos) {
|
||||||
|
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat)] = indices;
|
||||||
|
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat16)] = indices;
|
||||||
|
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat32)] = indices;
|
||||||
|
} else {
|
||||||
|
all_reduce_fusion_split_indices_[group] = indices;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitIndices(const std::string &group) const {
|
std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitIndices(const std::string &group) const {
|
||||||
auto iter = all_reduce_fusion_split_indices_.find(group);
|
auto iter = all_reduce_fusion_split_indices_.find(group);
|
||||||
if (iter != all_reduce_fusion_split_indices_.end()) {
|
if (iter != all_reduce_fusion_split_indices_.end()) {
|
||||||
return iter->second;
|
return iter->second;
|
||||||
|
@ -154,11 +162,19 @@ const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitIndices(cons
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
void ParallelContext::SetAllReduceFusionSplitSizes(const std::vector<uint32_t> sizes, const std::string &group) {
|
void ParallelContext::SetAllReduceFusionSplitSizes(const std::vector<uint32_t> &sizes, const std::string &group) {
|
||||||
all_reduce_fusion_split_sizes_[group] = sizes;
|
if (!group.empty() && group.find(TypeIdLabel(kNumberTypeFloat)) == std::string::npos &&
|
||||||
|
group.find(TypeIdLabel(kNumberTypeFloat16)) == std::string::npos &&
|
||||||
|
group.find(TypeIdLabel(kNumberTypeFloat32)) == std::string::npos) {
|
||||||
|
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat)] = sizes;
|
||||||
|
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat16)] = sizes;
|
||||||
|
all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat32)] = sizes;
|
||||||
|
} else {
|
||||||
|
all_reduce_fusion_split_indices_[group] = sizes;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const std::string &group) const {
|
std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const std::string &group) const {
|
||||||
auto iter = all_reduce_fusion_split_sizes_.find(group);
|
auto iter = all_reduce_fusion_split_sizes_.find(group);
|
||||||
if (iter != all_reduce_fusion_split_sizes_.end()) {
|
if (iter != all_reduce_fusion_split_sizes_.end()) {
|
||||||
return iter->second;
|
return iter->second;
|
||||||
|
|
|
@ -100,10 +100,10 @@ class ParallelContext {
|
||||||
void set_optimizer_weight_shard_integrated_save(bool optimizer_weight_shard_integrated_save);
|
void set_optimizer_weight_shard_integrated_save(bool optimizer_weight_shard_integrated_save);
|
||||||
bool optimizer_weight_shard_integrated_save() const { return optimizer_weight_shard_integrated_save_; }
|
bool optimizer_weight_shard_integrated_save() const { return optimizer_weight_shard_integrated_save_; }
|
||||||
|
|
||||||
void SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group);
|
void SetAllReduceFusionSplitIndices(const std::vector<uint32_t> &indices, const std::string &group);
|
||||||
const std::vector<uint32_t> GetAllReduceFusionSplitIndices(const std::string &group) const;
|
std::vector<uint32_t> GetAllReduceFusionSplitIndices(const std::string &group) const;
|
||||||
void SetAllReduceFusionSplitSizes(const std::vector<uint32_t> sizes, const std::string &group);
|
void SetAllReduceFusionSplitSizes(const std::vector<uint32_t> &sizes, const std::string &group);
|
||||||
const std::vector<uint32_t> GetAllReduceFusionSplitSizes(const std::string &group) const;
|
std::vector<uint32_t> GetAllReduceFusionSplitSizes(const std::string &group) const;
|
||||||
void set_enable_all_reduce_fusion(bool enable_all_reduce_fusion) {
|
void set_enable_all_reduce_fusion(bool enable_all_reduce_fusion) {
|
||||||
enable_all_reduce_fusion_ = enable_all_reduce_fusion;
|
enable_all_reduce_fusion_ = enable_all_reduce_fusion;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue