forked from mindspore-Ecosystem/mindspore
Add communication parallel mode.
This commit is contained in:
parent
5135c214b7
commit
6541b96c40
|
@ -33,6 +33,9 @@ std::vector<std::string> PARALLEL_MODE_LIST = {STAND_ALONE, DATA_PARALLEL, HYBRI
|
|||
AUTO_PARALLEL};
|
||||
std::vector<std::string> STRATEGY_SEARCH_MODE_LIST = {DYNAMIC_PROGRAMMING, RECURSIVE_PROGRAMMING};
|
||||
|
||||
std::vector<std::string> COMMUNI_PARALLEL_MODE_LIST = {ALL_GROUP_PARALLEL, SAME_SERVER_GROUP_PARALLEL,
|
||||
NO_GROUP_PARALLEL};
|
||||
|
||||
std::shared_ptr<ParallelContext> ParallelContext::inst_context_ = nullptr;
|
||||
|
||||
std::shared_ptr<ParallelContext> ParallelContext::GetInstance() {
|
||||
|
@ -65,6 +68,7 @@ void ParallelContext::Reset() {
|
|||
strategy_search_mode_ = DYNAMIC_PROGRAMMING;
|
||||
pipeline_stage_split_num_ = 1;
|
||||
grad_accumulation_step_ = 1;
|
||||
communi_parallel_mode_ = ALL_GROUP_PARALLEL;
|
||||
}
|
||||
|
||||
void ParallelContext::set_device_num(int64_t device_num) {
|
||||
|
@ -152,6 +156,17 @@ const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const
|
|||
return {};
|
||||
}
|
||||
|
||||
bool ParallelContext::set_communi_parallel_mode(const std::string &communi_parallel_mode) {
|
||||
auto iter = std::find(COMMUNI_PARALLEL_MODE_LIST.begin(), COMMUNI_PARALLEL_MODE_LIST.end(), communi_parallel_mode);
|
||||
if (iter == COMMUNI_PARALLEL_MODE_LIST.end()) {
|
||||
MS_LOG(INFO) << "Invalid communication parallel mode:" << communi_parallel_mode;
|
||||
return false;
|
||||
}
|
||||
|
||||
communi_parallel_mode_ = communi_parallel_mode;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
|
||||
void ParallelContext::ParallelParameterContextInitShape(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
|
|
@ -46,6 +46,10 @@ constexpr char RECURSIVE_PROGRAMMING[] = "recursive_programming";
|
|||
constexpr char TRAINING[] = "training";
|
||||
constexpr char ACCUMULATION[] = "accumulation";
|
||||
|
||||
constexpr char ALL_GROUP_PARALLEL[] = "all_group_parallel";
|
||||
constexpr char SAME_SERVER_GROUP_PARALLEL[] = "same_server_group_parallel";
|
||||
constexpr char NO_GROUP_PARALLEL[] = "no_group_parallel";
|
||||
|
||||
class ParallelContext {
|
||||
public:
|
||||
~ParallelContext() = default;
|
||||
|
@ -112,6 +116,9 @@ class ParallelContext {
|
|||
}
|
||||
bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; }
|
||||
|
||||
bool set_communi_parallel_mode(const std::string &communi_parallel_mode);
|
||||
std::string communi_parallel_mode() const { return communi_parallel_mode_; }
|
||||
|
||||
void Reset();
|
||||
void ParallelParameterContextInitShape(const FuncGraphPtr &func_graph);
|
||||
void ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node,
|
||||
|
@ -144,6 +151,7 @@ class ParallelContext {
|
|||
std::string group_ckpt_save_file_;
|
||||
bool enable_parallel_optimizer_;
|
||||
bool init_param_shape_;
|
||||
std::string communi_parallel_mode_;
|
||||
};
|
||||
|
||||
} // namespace parallel
|
||||
|
|
|
@ -169,6 +169,8 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
"Set enable/disable parallel optimizer.")
|
||||
.def("get_enable_parallel_optimizer", &ParallelContext::enable_parallel_optimizer,
|
||||
"Get enable/disable parallel optimizer.")
|
||||
.def("set_communi_parallel_mode", &ParallelContext::set_communi_parallel_mode, "Set communication parallel mode.")
|
||||
.def("get_communi_parallel_mode", &ParallelContext::communi_parallel_mode, "Get communication parallel mode.")
|
||||
.def("reset", &ParallelContext::Reset, "Reset auto parallel context.");
|
||||
|
||||
(void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")
|
||||
|
|
|
@ -40,6 +40,7 @@ namespace ascend {
|
|||
namespace {
|
||||
constexpr uint32_t kDeviceNumOfServer = 8;
|
||||
constexpr uint32_t kDeviceNumThreshold = 1024;
|
||||
const char kDefaultGroup[] = "__default_group";
|
||||
|
||||
constexpr uint32_t kMaxStreamNum = 1024;
|
||||
constexpr uint32_t kHcomSecondaryStreamNum = 3;
|
||||
|
@ -60,13 +61,48 @@ bool IsSameServer(const std::vector<uint32_t> &rank_ids) {
|
|||
return ((max - min < kDeviceNumOfServer) && (min / kDeviceNumOfServer == max / kDeviceNumOfServer));
|
||||
}
|
||||
|
||||
string DoGetHcomGroup(const string &original_group) {
|
||||
string communi_parallel_mode = parallel::ParallelContext::GetInstance()->communi_parallel_mode();
|
||||
|
||||
if (communi_parallel_mode == parallel::ALL_GROUP_PARALLEL) {
|
||||
return original_group;
|
||||
}
|
||||
|
||||
if (communi_parallel_mode == parallel::NO_GROUP_PARALLEL) {
|
||||
return kDefaultGroup;
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(parallel::g_device_manager);
|
||||
auto group_info = parallel::g_device_manager->group_info();
|
||||
for (const auto &info : group_info) {
|
||||
if (info.first != original_group) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto &rank_ids = info.second;
|
||||
if (IsSameServer(rank_ids)) {
|
||||
return original_group;
|
||||
} else {
|
||||
return kDefaultGroup;
|
||||
}
|
||||
}
|
||||
|
||||
// world group is not in group_info.
|
||||
return kDefaultGroup;
|
||||
}
|
||||
|
||||
string GetHcomGroup(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (!AnfAlgo::HasNodeAttr(kAttrGroup, cnode)) {
|
||||
MS_LOG_EXCEPTION << "Hcom node " << cnode->fullname_with_scope() << " has no group attribute.";
|
||||
}
|
||||
|
||||
return AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrGroup);
|
||||
auto group_name = AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrGroup);
|
||||
auto new_group = DoGetHcomGroup(group_name);
|
||||
MS_LOG_INFO << "hcom node: " << cnode->fullname_with_scope() << ", old group: " << group_name
|
||||
<< ", new group: " << new_group;
|
||||
|
||||
return new_group;
|
||||
}
|
||||
|
||||
uint32_t GetHcomTaskNum(const CNodePtr &cnode) {
|
||||
|
@ -167,6 +203,9 @@ StreamActiveKind GetStreamKind(uint32_t cur_stream_id, uint32_t pre_stream_id, u
|
|||
|
||||
void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) {
|
||||
if (IsTaskSink() && !graph_ptr->is_dynamic_shape()) {
|
||||
MS_LOG(INFO) << "Communication parallel mode: " << parallel::ParallelContext::GetInstance()->communi_parallel_mode()
|
||||
<< ".";
|
||||
|
||||
Reset();
|
||||
SetLoopSink();
|
||||
ReorderIndependentOrders(graph_ptr);
|
||||
|
|
|
@ -480,6 +480,26 @@ class _AutoParallelContext:
|
|||
self.check_context_handle()
|
||||
return self._context_handle.get_enable_parallel_optimizer()
|
||||
|
||||
def set_communi_parallel_mode(self, communi_parallel_mode):
|
||||
"""
|
||||
Set communication parallel mode.
|
||||
|
||||
Args:
|
||||
communi_parallel_mode (str): The communication parallel mode.
|
||||
|
||||
Raises:
|
||||
ValueError: If parallel mode is not supported.
|
||||
"""
|
||||
self.check_context_handle()
|
||||
ret = self._context_handle.set_communi_parallel_mode(communi_parallel_mode)
|
||||
if ret is False:
|
||||
raise ValueError("Communication parallel mode does not support {}".format(communi_parallel_mode))
|
||||
|
||||
def get_communi_parallel_mode(self):
|
||||
"""Get communication parallel mode."""
|
||||
self.check_context_handle()
|
||||
return self._context_handle.get_communi_parallel_mode()
|
||||
|
||||
def reset(self):
|
||||
"""Reset all settings."""
|
||||
self.check_context_handle()
|
||||
|
@ -518,7 +538,8 @@ _set_auto_parallel_context_func_map = {
|
|||
"full_batch": auto_parallel_context().set_full_batch,
|
||||
"enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer,
|
||||
"grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step,
|
||||
"all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices}
|
||||
"all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices,
|
||||
"communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode}
|
||||
|
||||
|
||||
_get_auto_parallel_context_func_map = {
|
||||
|
@ -536,14 +557,16 @@ _get_auto_parallel_context_func_map = {
|
|||
"full_batch": auto_parallel_context().get_full_batch,
|
||||
"enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer,
|
||||
"grad_accumulation_step": auto_parallel_context().get_grad_accumulation_step,
|
||||
"all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices}
|
||||
"all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices,
|
||||
"communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode}
|
||||
|
||||
|
||||
@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
|
||||
loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
|
||||
parameter_broadcast=bool, strategy_ckpt_load_file=str,
|
||||
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
|
||||
grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str)
|
||||
grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
|
||||
communi_parallel_mode=str)
|
||||
|
||||
def _set_auto_parallel_context(**kwargs):
|
||||
"""
|
||||
|
@ -592,6 +615,14 @@ def _set_auto_parallel_context(**kwargs):
|
|||
the devices are distributed alone the pipeline. The total devices will be divided into
|
||||
'pipeline_stags' stages. This currently could only be used when
|
||||
parall mode semi_auto_parallel is enabled. Default: 0
|
||||
communi_parallel_mode (str): There are tree kinds of communication parallel modes, "all_group_parallel",
|
||||
"same_server_group_parallel" and "no_group_parallel". Default: "all_group_parallel".
|
||||
|
||||
- all_group_parallel: All communication groups are in parallel.
|
||||
|
||||
- same_server_group_parallel: Only the communication groups within the same server are parallel.
|
||||
|
||||
- no_group_parallel: All communication groups are not parallel.
|
||||
|
||||
Raises:
|
||||
ValueError: If input key is not attribute in auto parallel context.
|
||||
|
|
|
@ -21,19 +21,22 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
|||
|
||||
def test_set_auto_parallel_context():
|
||||
context.set_auto_parallel_context(device_num=4, global_rank=3, gradients_mean=True, gradient_fp32_sync=False,
|
||||
parallel_mode="auto_parallel", parameter_broadcast=False)
|
||||
parallel_mode="auto_parallel", parameter_broadcast=False,
|
||||
communi_parallel_mode="same_server_group_parallel")
|
||||
device_num = context.get_auto_parallel_context("device_num")
|
||||
global_rank = context.get_auto_parallel_context("global_rank")
|
||||
gradients_mean = context.get_auto_parallel_context("gradients_mean")
|
||||
gradient_fp32_sync = context.get_auto_parallel_context("gradient_fp32_sync")
|
||||
parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
parameter_broadcast = context.get_auto_parallel_context("parameter_broadcast")
|
||||
communi_parallel_mode = context.get_auto_parallel_context("communi_parallel_mode")
|
||||
assert device_num == 4
|
||||
assert global_rank == 3
|
||||
assert gradients_mean
|
||||
assert not gradient_fp32_sync
|
||||
assert parallel_mode == "auto_parallel"
|
||||
assert not parameter_broadcast
|
||||
assert communi_parallel_mode == "same_server_group_parallel"
|
||||
|
||||
auto_parallel_context().set_device_num(4)
|
||||
device_num = auto_parallel_context().get_device_num()
|
||||
|
@ -77,6 +80,9 @@ def test_set_auto_parallel_context():
|
|||
with pytest.raises(ValueError):
|
||||
set_algo_parameters(tensor_slice_align_size=1025)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
context.set_auto_parallel_context(communi_parallel_mode="wrong_mode")
|
||||
|
||||
context.set_auto_parallel_context(enable_parallel_optimizer=True)
|
||||
assert context.get_auto_parallel_context("enable_parallel_optimizer")
|
||||
assert not auto_parallel_context().get_all_reduce_fusion_split_indices()
|
||||
|
@ -98,6 +104,7 @@ def test_reset_auto_parallel_context():
|
|||
device_num_is_set = auto_parallel_context().get_device_num_is_set()
|
||||
parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set()
|
||||
stage = auto_parallel_context().get_pipeline_stages()
|
||||
communi_parallel_mode = context.get_auto_parallel_context("communi_parallel_mode")
|
||||
|
||||
assert device_num == 1
|
||||
assert global_rank == 0
|
||||
|
@ -108,3 +115,4 @@ def test_reset_auto_parallel_context():
|
|||
assert not device_num_is_set
|
||||
assert not parameter_broadcast_is_set
|
||||
assert stage == 1
|
||||
assert communi_parallel_mode == "all_group_parallel"
|
||||
|
|
Loading…
Reference in New Issue