forked from mindspore-Ecosystem/mindspore
!1135 add group for allreduce fusion
Merge pull request !1135 from kisnwang/add-group-for-allreduce-fusion
This commit is contained in:
commit
311b7e71af
|
@ -113,20 +113,28 @@ void ParallelContext::set_strategy_ckpt_save_file(const std::string &strategy_ck
|
|||
strategy_ckpt_save_file_ = strategy_ckpt_save_file;
|
||||
}
|
||||
|
||||
void ParallelContext::set_all_reduce_fusion_split_indices(const std::vector<uint32_t> indices) {
|
||||
all_reduce_fusion_split_indices_ = indices;
|
||||
void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group) {
|
||||
all_reduce_fusion_split_indices_[group] = indices;
|
||||
}
|
||||
|
||||
const std::vector<uint32_t> ParallelContext::all_reduce_fusion_split_indices() const {
|
||||
return all_reduce_fusion_split_indices_;
|
||||
const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitIndices(const std::string &group) const {
|
||||
auto iter = all_reduce_fusion_split_indices_.find(group);
|
||||
if (iter != all_reduce_fusion_split_indices_.end()) {
|
||||
return iter->second;
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
void ParallelContext::set_all_reduce_fusion_split_sizes(const std::vector<uint32_t> sizes) {
|
||||
all_reduce_fusion_split_sizes_ = sizes;
|
||||
void ParallelContext::SetAllReduceFusionSplitSizes(const std::vector<uint32_t> sizes, const std::string &group) {
|
||||
all_reduce_fusion_split_sizes_[group] = sizes;
|
||||
}
|
||||
|
||||
const std::vector<uint32_t> ParallelContext::all_reduce_fusion_split_sizes() const {
|
||||
return all_reduce_fusion_split_sizes_;
|
||||
const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const std::string &group) const {
|
||||
auto iter = all_reduce_fusion_split_sizes_.find(group);
|
||||
if (iter != all_reduce_fusion_split_sizes_.end()) {
|
||||
return iter->second;
|
||||
}
|
||||
return {};
|
||||
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
|
@ -76,10 +77,10 @@ class ParallelContext {
|
|||
bool global_rank_is_set() const { return global_rank_is_set_; }
|
||||
bool parameter_broadcast_is_set() const { return parameter_broadcast_is_set_; }
|
||||
|
||||
void set_all_reduce_fusion_split_indices(const std::vector<uint32_t> indices);
|
||||
const std::vector<uint32_t> all_reduce_fusion_split_indices() const;
|
||||
void set_all_reduce_fusion_split_sizes(const std::vector<uint32_t> sizes);
|
||||
const std::vector<uint32_t> all_reduce_fusion_split_sizes() const;
|
||||
void SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group);
|
||||
const std::vector<uint32_t> GetAllReduceFusionSplitIndices(const std::string &group) const;
|
||||
void SetAllReduceFusionSplitSizes(const std::vector<uint32_t> sizes, const std::string &group);
|
||||
const std::vector<uint32_t> GetAllReduceFusionSplitSizes(const std::string &group) const;
|
||||
void set_enable_all_reduce_fusion(bool enable_all_reduce_fusion) {
|
||||
enable_all_reduce_fusion_ = enable_all_reduce_fusion;
|
||||
}
|
||||
|
@ -108,8 +109,8 @@ class ParallelContext {
|
|||
bool global_rank_is_set_;
|
||||
bool parameter_broadcast_is_set_;
|
||||
bool enable_all_reduce_fusion_;
|
||||
std::vector<uint32_t> all_reduce_fusion_split_indices_;
|
||||
std::vector<uint32_t> all_reduce_fusion_split_sizes_;
|
||||
std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_indices_;
|
||||
std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_sizes_;
|
||||
std::string strategy_ckpt_load_file_;
|
||||
std::string strategy_ckpt_save_file_;
|
||||
};
|
||||
|
|
|
@ -159,13 +159,13 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
.def("set_parallel_mode", &ParallelContext::set_parallel_mode, "Set parallel mode.")
|
||||
.def("get_strategy_search_mode", &ParallelContext::strategy_search_mode, "Get strategy search mode.")
|
||||
.def("set_strategy_search_mode", &ParallelContext::set_strategy_search_mode, "Set strategy search mode.")
|
||||
.def("set_all_reduce_fusion_split_indices", &ParallelContext::set_all_reduce_fusion_split_indices,
|
||||
.def("set_all_reduce_fusion_split_indices", &ParallelContext::SetAllReduceFusionSplitIndices,
|
||||
"Set all reduce fusion split indices.")
|
||||
.def("get_all_reduce_fusion_split_indices", &ParallelContext::all_reduce_fusion_split_indices,
|
||||
.def("get_all_reduce_fusion_split_indices", &ParallelContext::GetAllReduceFusionSplitIndices,
|
||||
"Get all reduce fusion split indices.")
|
||||
.def("set_all_reduce_fusion_split_sizes", &ParallelContext::set_all_reduce_fusion_split_sizes,
|
||||
.def("set_all_reduce_fusion_split_sizes", &ParallelContext::SetAllReduceFusionSplitSizes,
|
||||
"Set all reduce fusion split sizes.")
|
||||
.def("get_all_reduce_fusion_split_sizes", &ParallelContext::all_reduce_fusion_split_sizes,
|
||||
.def("get_all_reduce_fusion_split_sizes", &ParallelContext::GetAllReduceFusionSplitSizes,
|
||||
"Get all reduce fusion split sizes.")
|
||||
.def("set_enable_all_reduce_fusion", &ParallelContext::set_enable_all_reduce_fusion,
|
||||
"Set enable/disable all reduce fusion.")
|
||||
|
|
|
@ -92,7 +92,7 @@ std::string GetFusionGroupKey(const AnfNodePtr &node) {
|
|||
} // namespace
|
||||
|
||||
bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num,
|
||||
std::vector<size_t> *segment_index) const {
|
||||
std::vector<size_t> *segment_index, const std::string &group) const {
|
||||
MS_EXCEPTION_IF_NULL(segment_num);
|
||||
MS_EXCEPTION_IF_NULL(segment_index);
|
||||
size_t communication_op_node_size = communication_op_info.communication_op_nodes.size();
|
||||
|
@ -100,7 +100,7 @@ bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communic
|
|||
|
||||
auto parallel_context = parallel::ParallelContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(parallel_context);
|
||||
const std::vector<uint32_t> split_indices = parallel_context->all_reduce_fusion_split_indices();
|
||||
const auto &split_indices = parallel_context->GetAllReduceFusionSplitIndices(group);
|
||||
|
||||
size_t segments = 0;
|
||||
if (split_indices.size() != 0) {
|
||||
|
@ -255,7 +255,7 @@ bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
size_t segment_num = 0;
|
||||
std::vector<size_t> segment_index;
|
||||
if (GetSplitSegments(it.second, &segment_num, &segment_index)) {
|
||||
if (GetSplitSegments(it.second, &segment_num, &segment_index, it.first)) {
|
||||
if (DoFusion(func_graph, it.second, segment_num, segment_index)) {
|
||||
changed = true;
|
||||
}
|
||||
|
|
|
@ -46,7 +46,7 @@ class CommunicationOpFusion : public Pass {
|
|||
const CommunicationOpInfo &communication_op_info, size_t start_index,
|
||||
size_t end_index) const;
|
||||
bool GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num,
|
||||
std::vector<size_t> *segment_index) const;
|
||||
std::vector<size_t> *segment_index, const std::string &group) const;
|
||||
std::string op_name_;
|
||||
size_t groups_ = 1;
|
||||
};
|
||||
|
|
|
@ -19,6 +19,8 @@ from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx,
|
|||
from mindspore._c_expression import AutoParallelContext
|
||||
from mindspore._checkparam import args_type_check
|
||||
|
||||
_MAX_GROUP_NAME_LEN = 127
|
||||
|
||||
|
||||
class _AutoParallelContext:
|
||||
"""
|
||||
|
@ -243,51 +245,117 @@ class _AutoParallelContext:
|
|||
self.check_context_handle()
|
||||
return self._context_handle.get_parameter_broadcast_is_set()
|
||||
|
||||
def set_all_reduce_fusion_split_indices(self, indices):
|
||||
def set_all_reduce_fusion_split_indices(self, indices, group=""):
|
||||
"""
|
||||
Set allreduce fusion strategy by parameters indices.
|
||||
|
||||
Args:
|
||||
indices (list): Indices list.
|
||||
group (str): The hccl communication group.
|
||||
|
||||
Raises:
|
||||
TypeError: If type of indices item is not int.
|
||||
TypeError: If group is not a python str.
|
||||
"""
|
||||
self.check_context_handle()
|
||||
for index in indices:
|
||||
if not isinstance(index, int):
|
||||
raise TypeError('indices has invalid value')
|
||||
self._context_handle.set_all_reduce_fusion_split_indices(indices)
|
||||
if isinstance(indices, (list)):
|
||||
for index in indices:
|
||||
if not isinstance(index, int):
|
||||
raise TypeError('indices has invalid value')
|
||||
else:
|
||||
raise TypeError('indices must be a python list')
|
||||
|
||||
if isinstance(group, (str)):
|
||||
group_len = len(group)
|
||||
if group_len > _MAX_GROUP_NAME_LEN:
|
||||
raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
|
||||
else:
|
||||
raise TypeError('Group must be a python str')
|
||||
|
||||
self._context_handle.set_all_reduce_fusion_split_indices(indices, group)
|
||||
if context.get_context("device_target") == "Ascend":
|
||||
_set_fusion_strategy_by_idx(indices)
|
||||
if group == "":
|
||||
_set_fusion_strategy_by_idx(indices)
|
||||
else:
|
||||
_set_fusion_strategy_by_idx(indices, group)
|
||||
|
||||
def get_all_reduce_fusion_split_indices(self):
|
||||
"""Get allreduce fusion split indices."""
|
||||
def get_all_reduce_fusion_split_indices(self, group=""):
|
||||
"""
|
||||
Get allreduce fusion split indices.
|
||||
|
||||
Args:
|
||||
group (str): The hccl communication group.
|
||||
|
||||
Returns:
|
||||
Return split sizes list according to the group.
|
||||
|
||||
Raises:
|
||||
TypeError: If group is not a python str.
|
||||
"""
|
||||
self.check_context_handle()
|
||||
return self._context_handle.get_all_reduce_fusion_split_indices()
|
||||
if isinstance(group, (str)):
|
||||
group_len = len(group)
|
||||
if group_len > _MAX_GROUP_NAME_LEN:
|
||||
raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
|
||||
else:
|
||||
raise TypeError('Group must be a python str')
|
||||
return self._context_handle.get_all_reduce_fusion_split_indices(group)
|
||||
|
||||
def set_all_reduce_fusion_split_sizes(self, sizes):
|
||||
def set_all_reduce_fusion_split_sizes(self, sizes, group=""):
|
||||
"""
|
||||
Set allreduce fusion strategy by parameters data sizes.
|
||||
|
||||
Args:
|
||||
sizes (list): Sizes list.
|
||||
group (str): The hccl communication group.
|
||||
|
||||
Raises:
|
||||
TypeError: If type of sizes item is not int.
|
||||
TypeError: If group is not a python str.
|
||||
"""
|
||||
self.check_context_handle()
|
||||
for size in sizes:
|
||||
if not isinstance(size, int):
|
||||
raise TypeError('sizes has invalid value')
|
||||
self._context_handle.set_all_reduce_fusion_split_sizes(sizes)
|
||||
if context.get_context("device_target") == "Ascend":
|
||||
_set_fusion_strategy_by_size(sizes)
|
||||
if isinstance(sizes, (list)):
|
||||
for size in sizes:
|
||||
if not isinstance(size, int):
|
||||
raise TypeError('sizes has invalid value')
|
||||
else:
|
||||
raise TypeError('sizes must be a python list')
|
||||
|
||||
def get_all_reduce_fusion_split_sizes(self):
|
||||
"""Get allreduce fusion split sizes."""
|
||||
if isinstance(group, (str)):
|
||||
group_len = len(group)
|
||||
if group_len > _MAX_GROUP_NAME_LEN:
|
||||
raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
|
||||
else:
|
||||
raise TypeError('Group must be a python str')
|
||||
|
||||
self._context_handle.set_all_reduce_fusion_split_sizes(sizes, group)
|
||||
if context.get_context("device_target") == "Ascend":
|
||||
if group == "":
|
||||
_set_fusion_strategy_by_size(sizes)
|
||||
else:
|
||||
_set_fusion_strategy_by_size(sizes, group)
|
||||
|
||||
def get_all_reduce_fusion_split_sizes(self, group=""):
|
||||
"""
|
||||
Get allreduce fusion split sizes.
|
||||
|
||||
Args:
|
||||
group (str): The hccl communication group.
|
||||
|
||||
Returns:
|
||||
Return split sizes list according to the group.
|
||||
|
||||
Raises:
|
||||
TypeError: If group is not a python str.
|
||||
"""
|
||||
self.check_context_handle()
|
||||
return self._context_handle.get_all_reduce_fusion_split_sizes()
|
||||
if isinstance(group, (str)):
|
||||
group_len = len(group)
|
||||
if group_len > _MAX_GROUP_NAME_LEN:
|
||||
raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
|
||||
else:
|
||||
raise TypeError('Group must be a python str')
|
||||
return self._context_handle.get_all_reduce_fusion_split_sizes(group)
|
||||
|
||||
def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue