Add communication parallel mode.

This commit is contained in:
liujunzhu 2021-03-05 10:03:31 +08:00
parent 5135c214b7
commit 6541b96c40
6 changed files with 108 additions and 5 deletions

View File

@ -33,6 +33,9 @@ std::vector<std::string> PARALLEL_MODE_LIST = {STAND_ALONE, DATA_PARALLEL, HYBRI
AUTO_PARALLEL};
std::vector<std::string> STRATEGY_SEARCH_MODE_LIST = {DYNAMIC_PROGRAMMING, RECURSIVE_PROGRAMMING};
std::vector<std::string> COMMUNI_PARALLEL_MODE_LIST = {ALL_GROUP_PARALLEL, SAME_SERVER_GROUP_PARALLEL,
NO_GROUP_PARALLEL};
std::shared_ptr<ParallelContext> ParallelContext::inst_context_ = nullptr;
std::shared_ptr<ParallelContext> ParallelContext::GetInstance() {
@ -65,6 +68,7 @@ void ParallelContext::Reset() {
strategy_search_mode_ = DYNAMIC_PROGRAMMING;
pipeline_stage_split_num_ = 1;
grad_accumulation_step_ = 1;
communi_parallel_mode_ = ALL_GROUP_PARALLEL;
}
void ParallelContext::set_device_num(int64_t device_num) {
@ -152,6 +156,17 @@ const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const
return {};
}
// Sets the communication parallel mode if it is one of the supported entries
// in COMMUNI_PARALLEL_MODE_LIST; returns false (and logs) otherwise.
bool ParallelContext::set_communi_parallel_mode(const std::string &communi_parallel_mode) {
  const bool is_supported =
    std::any_of(COMMUNI_PARALLEL_MODE_LIST.begin(), COMMUNI_PARALLEL_MODE_LIST.end(),
                [&communi_parallel_mode](const std::string &mode) { return mode == communi_parallel_mode; });
  if (!is_supported) {
    MS_LOG(INFO) << "Invalid communication parallel mode:" << communi_parallel_mode;
    return false;
  }
  communi_parallel_mode_ = communi_parallel_mode;
  return true;
}
// Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
void ParallelContext::ParallelParameterContextInitShape(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);

View File

@ -46,6 +46,10 @@ constexpr char RECURSIVE_PROGRAMMING[] = "recursive_programming";
constexpr char TRAINING[] = "training";
constexpr char ACCUMULATION[] = "accumulation";
constexpr char ALL_GROUP_PARALLEL[] = "all_group_parallel";
constexpr char SAME_SERVER_GROUP_PARALLEL[] = "same_server_group_parallel";
constexpr char NO_GROUP_PARALLEL[] = "no_group_parallel";
class ParallelContext {
public:
~ParallelContext() = default;
@ -112,6 +116,9 @@ class ParallelContext {
}
bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; }
bool set_communi_parallel_mode(const std::string &communi_parallel_mode);
std::string communi_parallel_mode() const { return communi_parallel_mode_; }
void Reset();
void ParallelParameterContextInitShape(const FuncGraphPtr &func_graph);
void ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
@ -144,6 +151,7 @@ class ParallelContext {
std::string group_ckpt_save_file_;
bool enable_parallel_optimizer_;
bool init_param_shape_;
std::string communi_parallel_mode_;
};
} // namespace parallel

View File

@ -169,6 +169,8 @@ PYBIND11_MODULE(_c_expression, m) {
"Set enable/disable parallel optimizer.")
.def("get_enable_parallel_optimizer", &ParallelContext::enable_parallel_optimizer,
"Get enable/disable parallel optimizer.")
.def("set_communi_parallel_mode", &ParallelContext::set_communi_parallel_mode, "Set communication parallel mode.")
.def("get_communi_parallel_mode", &ParallelContext::communi_parallel_mode, "Get communication parallel mode.")
.def("reset", &ParallelContext::Reset, "Reset auto parallel context.");
(void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")

View File

@ -40,6 +40,7 @@ namespace ascend {
namespace {
constexpr uint32_t kDeviceNumOfServer = 8;
constexpr uint32_t kDeviceNumThreshold = 1024;
const char kDefaultGroup[] = "__default_group";
constexpr uint32_t kMaxStreamNum = 1024;
constexpr uint32_t kHcomSecondaryStreamNum = 3;
@ -60,13 +61,48 @@ bool IsSameServer(const std::vector<uint32_t> &rank_ids) {
return ((max - min < kDeviceNumOfServer) && (min / kDeviceNumOfServer == max / kDeviceNumOfServer));
}
// Maps a communication group to the group actually assigned, according to the
// configured communication parallel mode.
string DoGetHcomGroup(const string &original_group) {
  const string mode = parallel::ParallelContext::GetInstance()->communi_parallel_mode();
  if (mode == parallel::ALL_GROUP_PARALLEL) {
    // All groups may run in parallel: keep the original group untouched.
    return original_group;
  }
  if (mode == parallel::NO_GROUP_PARALLEL) {
    // No group parallelism: funnel every node into the shared default group.
    return kDefaultGroup;
  }
  // SAME_SERVER_GROUP_PARALLEL: keep the group only when all of its ranks
  // reside on the same server; otherwise fall back to the default group.
  MS_EXCEPTION_IF_NULL(parallel::g_device_manager);
  const auto group_info = parallel::g_device_manager->group_info();
  for (const auto &info : group_info) {
    if (info.first == original_group) {
      return IsSameServer(info.second) ? original_group : kDefaultGroup;
    }
  }
  // The world group is not recorded in group_info, so it lands here.
  return kDefaultGroup;
}
// Returns the HCCL group to use for a communication node, after remapping it
// through DoGetHcomGroup() per the communication parallel mode.
// Raises via MS_LOG_EXCEPTION when the node carries no group attribute.
string GetHcomGroup(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  if (!AnfAlgo::HasNodeAttr(kAttrGroup, cnode)) {
    MS_LOG_EXCEPTION << "Hcom node " << cnode->fullname_with_scope() << " has no group attribute.";
  }
  // NOTE: the stray `return AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrGroup);`
  // that preceded the remapping made the lines below unreachable; removed.
  auto group_name = AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrGroup);
  auto new_group = DoGetHcomGroup(group_name);
  MS_LOG_INFO << "hcom node: " << cnode->fullname_with_scope() << ", old group: " << group_name
              << ", new group: " << new_group;
  return new_group;
}
uint32_t GetHcomTaskNum(const CNodePtr &cnode) {
@ -167,6 +203,9 @@ StreamActiveKind GetStreamKind(uint32_t cur_stream_id, uint32_t pre_stream_id, u
void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) {
if (IsTaskSink() && !graph_ptr->is_dynamic_shape()) {
MS_LOG(INFO) << "Communication parallel mode: " << parallel::ParallelContext::GetInstance()->communi_parallel_mode()
<< ".";
Reset();
SetLoopSink();
ReorderIndependentOrders(graph_ptr);

View File

@ -480,6 +480,26 @@ class _AutoParallelContext:
self.check_context_handle()
return self._context_handle.get_enable_parallel_optimizer()
def set_communi_parallel_mode(self, communi_parallel_mode):
    """
    Set communication parallel mode.

    Args:
        communi_parallel_mode (str): The communication parallel mode; must be one of
            the modes accepted by the C++ context (e.g. "all_group_parallel").

    Raises:
        ValueError: If parallel mode is not supported.
    """
    self.check_context_handle()
    ret = self._context_handle.set_communi_parallel_mode(communi_parallel_mode)
    # PEP 8: never compare to a bool literal with `is`; truthiness is equivalent
    # here and robust if the binding ever returns a non-bool truthy value.
    if not ret:
        raise ValueError("Communication parallel mode does not support {}".format(communi_parallel_mode))
def get_communi_parallel_mode(self):
    """Get communication parallel mode."""
    self.check_context_handle()
    handle = self._context_handle
    return handle.get_communi_parallel_mode()
def reset(self):
"""Reset all settings."""
self.check_context_handle()
@ -518,7 +538,8 @@ _set_auto_parallel_context_func_map = {
"full_batch": auto_parallel_context().set_full_batch,
"enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer,
"grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step,
"all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices}
"all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices,
"communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode}
_get_auto_parallel_context_func_map = {
@ -536,14 +557,16 @@ _get_auto_parallel_context_func_map = {
"full_batch": auto_parallel_context().get_full_batch,
"enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer,
"grad_accumulation_step": auto_parallel_context().get_grad_accumulation_step,
"all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices}
"all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices,
"communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode}
@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
parameter_broadcast=bool, strategy_ckpt_load_file=str,
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str)
grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
communi_parallel_mode=str)
def _set_auto_parallel_context(**kwargs):
"""
@ -592,6 +615,14 @@ def _set_auto_parallel_context(**kwargs):
the devices are distributed along the pipeline. The total devices will be divided into
'pipeline_stages' stages. This currently could only be used when
parallel mode semi_auto_parallel is enabled. Default: 0
communi_parallel_mode (str): There are three kinds of communication parallel modes, "all_group_parallel",
"same_server_group_parallel" and "no_group_parallel". Default: "all_group_parallel".
- all_group_parallel: All communication groups are in parallel.
- same_server_group_parallel: Only the communication groups within the same server are parallel.
- no_group_parallel: All communication groups are not parallel.
Raises:
ValueError: If input key is not attribute in auto parallel context.

View File

@ -21,19 +21,22 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context
def test_set_auto_parallel_context():
context.set_auto_parallel_context(device_num=4, global_rank=3, gradients_mean=True, gradient_fp32_sync=False,
parallel_mode="auto_parallel", parameter_broadcast=False)
parallel_mode="auto_parallel", parameter_broadcast=False,
communi_parallel_mode="same_server_group_parallel")
device_num = context.get_auto_parallel_context("device_num")
global_rank = context.get_auto_parallel_context("global_rank")
gradients_mean = context.get_auto_parallel_context("gradients_mean")
gradient_fp32_sync = context.get_auto_parallel_context("gradient_fp32_sync")
parallel_mode = context.get_auto_parallel_context("parallel_mode")
parameter_broadcast = context.get_auto_parallel_context("parameter_broadcast")
communi_parallel_mode = context.get_auto_parallel_context("communi_parallel_mode")
assert device_num == 4
assert global_rank == 3
assert gradients_mean
assert not gradient_fp32_sync
assert parallel_mode == "auto_parallel"
assert not parameter_broadcast
assert communi_parallel_mode == "same_server_group_parallel"
auto_parallel_context().set_device_num(4)
device_num = auto_parallel_context().get_device_num()
@ -77,6 +80,9 @@ def test_set_auto_parallel_context():
with pytest.raises(ValueError):
set_algo_parameters(tensor_slice_align_size=1025)
with pytest.raises(ValueError):
context.set_auto_parallel_context(communi_parallel_mode="wrong_mode")
context.set_auto_parallel_context(enable_parallel_optimizer=True)
assert context.get_auto_parallel_context("enable_parallel_optimizer")
assert not auto_parallel_context().get_all_reduce_fusion_split_indices()
@ -98,6 +104,7 @@ def test_reset_auto_parallel_context():
device_num_is_set = auto_parallel_context().get_device_num_is_set()
parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set()
stage = auto_parallel_context().get_pipeline_stages()
communi_parallel_mode = context.get_auto_parallel_context("communi_parallel_mode")
assert device_num == 1
assert global_rank == 0
@ -108,3 +115,4 @@ def test_reset_auto_parallel_context():
assert not device_num_is_set
assert not parameter_broadcast_is_set
assert stage == 1
assert communi_parallel_mode == "all_group_parallel"