forked from mindspore-Ecosystem/mindspore
Add communication parallel mode.
This commit is contained in:
parent
5135c214b7
commit
6541b96c40
|
@ -33,6 +33,9 @@ std::vector<std::string> PARALLEL_MODE_LIST = {STAND_ALONE, DATA_PARALLEL, HYBRI
|
||||||
AUTO_PARALLEL};
|
AUTO_PARALLEL};
|
||||||
std::vector<std::string> STRATEGY_SEARCH_MODE_LIST = {DYNAMIC_PROGRAMMING, RECURSIVE_PROGRAMMING};
|
std::vector<std::string> STRATEGY_SEARCH_MODE_LIST = {DYNAMIC_PROGRAMMING, RECURSIVE_PROGRAMMING};
|
||||||
|
|
||||||
|
std::vector<std::string> COMMUNI_PARALLEL_MODE_LIST = {ALL_GROUP_PARALLEL, SAME_SERVER_GROUP_PARALLEL,
|
||||||
|
NO_GROUP_PARALLEL};
|
||||||
|
|
||||||
std::shared_ptr<ParallelContext> ParallelContext::inst_context_ = nullptr;
|
std::shared_ptr<ParallelContext> ParallelContext::inst_context_ = nullptr;
|
||||||
|
|
||||||
std::shared_ptr<ParallelContext> ParallelContext::GetInstance() {
|
std::shared_ptr<ParallelContext> ParallelContext::GetInstance() {
|
||||||
|
@ -65,6 +68,7 @@ void ParallelContext::Reset() {
|
||||||
strategy_search_mode_ = DYNAMIC_PROGRAMMING;
|
strategy_search_mode_ = DYNAMIC_PROGRAMMING;
|
||||||
pipeline_stage_split_num_ = 1;
|
pipeline_stage_split_num_ = 1;
|
||||||
grad_accumulation_step_ = 1;
|
grad_accumulation_step_ = 1;
|
||||||
|
communi_parallel_mode_ = ALL_GROUP_PARALLEL;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ParallelContext::set_device_num(int64_t device_num) {
|
void ParallelContext::set_device_num(int64_t device_num) {
|
||||||
|
@ -152,6 +156,17 @@ const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool ParallelContext::set_communi_parallel_mode(const std::string &communi_parallel_mode) {
|
||||||
|
auto iter = std::find(COMMUNI_PARALLEL_MODE_LIST.begin(), COMMUNI_PARALLEL_MODE_LIST.end(), communi_parallel_mode);
|
||||||
|
if (iter == COMMUNI_PARALLEL_MODE_LIST.end()) {
|
||||||
|
MS_LOG(INFO) << "Invalid communication parallel mode:" << communi_parallel_mode;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
communi_parallel_mode_ = communi_parallel_mode;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
// Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
|
// Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
|
||||||
void ParallelContext::ParallelParameterContextInitShape(const FuncGraphPtr &func_graph) {
|
void ParallelContext::ParallelParameterContextInitShape(const FuncGraphPtr &func_graph) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
|
|
@ -46,6 +46,10 @@ constexpr char RECURSIVE_PROGRAMMING[] = "recursive_programming";
|
||||||
constexpr char TRAINING[] = "training";
|
constexpr char TRAINING[] = "training";
|
||||||
constexpr char ACCUMULATION[] = "accumulation";
|
constexpr char ACCUMULATION[] = "accumulation";
|
||||||
|
|
||||||
|
constexpr char ALL_GROUP_PARALLEL[] = "all_group_parallel";
|
||||||
|
constexpr char SAME_SERVER_GROUP_PARALLEL[] = "same_server_group_parallel";
|
||||||
|
constexpr char NO_GROUP_PARALLEL[] = "no_group_parallel";
|
||||||
|
|
||||||
class ParallelContext {
|
class ParallelContext {
|
||||||
public:
|
public:
|
||||||
~ParallelContext() = default;
|
~ParallelContext() = default;
|
||||||
|
@ -112,6 +116,9 @@ class ParallelContext {
|
||||||
}
|
}
|
||||||
bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; }
|
bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; }
|
||||||
|
|
||||||
|
bool set_communi_parallel_mode(const std::string &communi_parallel_mode);
|
||||||
|
std::string communi_parallel_mode() const { return communi_parallel_mode_; }
|
||||||
|
|
||||||
void Reset();
|
void Reset();
|
||||||
void ParallelParameterContextInitShape(const FuncGraphPtr &func_graph);
|
void ParallelParameterContextInitShape(const FuncGraphPtr &func_graph);
|
||||||
void ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node,
|
void ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node,
|
||||||
|
@ -144,6 +151,7 @@ class ParallelContext {
|
||||||
std::string group_ckpt_save_file_;
|
std::string group_ckpt_save_file_;
|
||||||
bool enable_parallel_optimizer_;
|
bool enable_parallel_optimizer_;
|
||||||
bool init_param_shape_;
|
bool init_param_shape_;
|
||||||
|
std::string communi_parallel_mode_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
|
|
|
@ -169,6 +169,8 @@ PYBIND11_MODULE(_c_expression, m) {
|
||||||
"Set enable/disable parallel optimizer.")
|
"Set enable/disable parallel optimizer.")
|
||||||
.def("get_enable_parallel_optimizer", &ParallelContext::enable_parallel_optimizer,
|
.def("get_enable_parallel_optimizer", &ParallelContext::enable_parallel_optimizer,
|
||||||
"Get enable/disable parallel optimizer.")
|
"Get enable/disable parallel optimizer.")
|
||||||
|
.def("set_communi_parallel_mode", &ParallelContext::set_communi_parallel_mode, "Set communication parallel mode.")
|
||||||
|
.def("get_communi_parallel_mode", &ParallelContext::communi_parallel_mode, "Get communication parallel mode.")
|
||||||
.def("reset", &ParallelContext::Reset, "Reset auto parallel context.");
|
.def("reset", &ParallelContext::Reset, "Reset auto parallel context.");
|
||||||
|
|
||||||
(void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")
|
(void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")
|
||||||
|
|
|
@ -40,6 +40,7 @@ namespace ascend {
|
||||||
namespace {
|
namespace {
|
||||||
constexpr uint32_t kDeviceNumOfServer = 8;
|
constexpr uint32_t kDeviceNumOfServer = 8;
|
||||||
constexpr uint32_t kDeviceNumThreshold = 1024;
|
constexpr uint32_t kDeviceNumThreshold = 1024;
|
||||||
|
const char kDefaultGroup[] = "__default_group";
|
||||||
|
|
||||||
constexpr uint32_t kMaxStreamNum = 1024;
|
constexpr uint32_t kMaxStreamNum = 1024;
|
||||||
constexpr uint32_t kHcomSecondaryStreamNum = 3;
|
constexpr uint32_t kHcomSecondaryStreamNum = 3;
|
||||||
|
@ -60,13 +61,48 @@ bool IsSameServer(const std::vector<uint32_t> &rank_ids) {
|
||||||
return ((max - min < kDeviceNumOfServer) && (min / kDeviceNumOfServer == max / kDeviceNumOfServer));
|
return ((max - min < kDeviceNumOfServer) && (min / kDeviceNumOfServer == max / kDeviceNumOfServer));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
string DoGetHcomGroup(const string &original_group) {
|
||||||
|
string communi_parallel_mode = parallel::ParallelContext::GetInstance()->communi_parallel_mode();
|
||||||
|
|
||||||
|
if (communi_parallel_mode == parallel::ALL_GROUP_PARALLEL) {
|
||||||
|
return original_group;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (communi_parallel_mode == parallel::NO_GROUP_PARALLEL) {
|
||||||
|
return kDefaultGroup;
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_EXCEPTION_IF_NULL(parallel::g_device_manager);
|
||||||
|
auto group_info = parallel::g_device_manager->group_info();
|
||||||
|
for (const auto &info : group_info) {
|
||||||
|
if (info.first != original_group) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto &rank_ids = info.second;
|
||||||
|
if (IsSameServer(rank_ids)) {
|
||||||
|
return original_group;
|
||||||
|
} else {
|
||||||
|
return kDefaultGroup;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// world group is not in group_info.
|
||||||
|
return kDefaultGroup;
|
||||||
|
}
|
||||||
|
|
||||||
string GetHcomGroup(const CNodePtr &cnode) {
|
string GetHcomGroup(const CNodePtr &cnode) {
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
if (!AnfAlgo::HasNodeAttr(kAttrGroup, cnode)) {
|
if (!AnfAlgo::HasNodeAttr(kAttrGroup, cnode)) {
|
||||||
MS_LOG_EXCEPTION << "Hcom node " << cnode->fullname_with_scope() << " has no group attribute.";
|
MS_LOG_EXCEPTION << "Hcom node " << cnode->fullname_with_scope() << " has no group attribute.";
|
||||||
}
|
}
|
||||||
|
|
||||||
return AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrGroup);
|
auto group_name = AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrGroup);
|
||||||
|
auto new_group = DoGetHcomGroup(group_name);
|
||||||
|
MS_LOG_INFO << "hcom node: " << cnode->fullname_with_scope() << ", old group: " << group_name
|
||||||
|
<< ", new group: " << new_group;
|
||||||
|
|
||||||
|
return new_group;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t GetHcomTaskNum(const CNodePtr &cnode) {
|
uint32_t GetHcomTaskNum(const CNodePtr &cnode) {
|
||||||
|
@ -167,6 +203,9 @@ StreamActiveKind GetStreamKind(uint32_t cur_stream_id, uint32_t pre_stream_id, u
|
||||||
|
|
||||||
void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) {
|
void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) {
|
||||||
if (IsTaskSink() && !graph_ptr->is_dynamic_shape()) {
|
if (IsTaskSink() && !graph_ptr->is_dynamic_shape()) {
|
||||||
|
MS_LOG(INFO) << "Communication parallel mode: " << parallel::ParallelContext::GetInstance()->communi_parallel_mode()
|
||||||
|
<< ".";
|
||||||
|
|
||||||
Reset();
|
Reset();
|
||||||
SetLoopSink();
|
SetLoopSink();
|
||||||
ReorderIndependentOrders(graph_ptr);
|
ReorderIndependentOrders(graph_ptr);
|
||||||
|
|
|
@ -480,6 +480,26 @@ class _AutoParallelContext:
|
||||||
self.check_context_handle()
|
self.check_context_handle()
|
||||||
return self._context_handle.get_enable_parallel_optimizer()
|
return self._context_handle.get_enable_parallel_optimizer()
|
||||||
|
|
||||||
|
def set_communi_parallel_mode(self, communi_parallel_mode):
|
||||||
|
"""
|
||||||
|
Set communication parallel mode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
communi_parallel_mode (str): The communication parallel mode.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If parallel mode is not supported.
|
||||||
|
"""
|
||||||
|
self.check_context_handle()
|
||||||
|
ret = self._context_handle.set_communi_parallel_mode(communi_parallel_mode)
|
||||||
|
if ret is False:
|
||||||
|
raise ValueError("Communication parallel mode does not support {}".format(communi_parallel_mode))
|
||||||
|
|
||||||
|
def get_communi_parallel_mode(self):
|
||||||
|
"""Get communication parallel mode."""
|
||||||
|
self.check_context_handle()
|
||||||
|
return self._context_handle.get_communi_parallel_mode()
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""Reset all settings."""
|
"""Reset all settings."""
|
||||||
self.check_context_handle()
|
self.check_context_handle()
|
||||||
|
@ -518,7 +538,8 @@ _set_auto_parallel_context_func_map = {
|
||||||
"full_batch": auto_parallel_context().set_full_batch,
|
"full_batch": auto_parallel_context().set_full_batch,
|
||||||
"enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer,
|
"enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer,
|
||||||
"grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step,
|
"grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step,
|
||||||
"all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices}
|
"all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices,
|
||||||
|
"communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode}
|
||||||
|
|
||||||
|
|
||||||
_get_auto_parallel_context_func_map = {
|
_get_auto_parallel_context_func_map = {
|
||||||
|
@ -536,14 +557,16 @@ _get_auto_parallel_context_func_map = {
|
||||||
"full_batch": auto_parallel_context().get_full_batch,
|
"full_batch": auto_parallel_context().get_full_batch,
|
||||||
"enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer,
|
"enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer,
|
||||||
"grad_accumulation_step": auto_parallel_context().get_grad_accumulation_step,
|
"grad_accumulation_step": auto_parallel_context().get_grad_accumulation_step,
|
||||||
"all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices}
|
"all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices,
|
||||||
|
"communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode}
|
||||||
|
|
||||||
|
|
||||||
@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
|
@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
|
||||||
loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
|
loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
|
||||||
parameter_broadcast=bool, strategy_ckpt_load_file=str,
|
parameter_broadcast=bool, strategy_ckpt_load_file=str,
|
||||||
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
|
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
|
||||||
grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str)
|
grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
|
||||||
|
communi_parallel_mode=str)
|
||||||
|
|
||||||
def _set_auto_parallel_context(**kwargs):
|
def _set_auto_parallel_context(**kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -592,6 +615,14 @@ def _set_auto_parallel_context(**kwargs):
|
||||||
the devices are distributed alone the pipeline. The total devices will be divided into
|
the devices are distributed alone the pipeline. The total devices will be divided into
|
||||||
'pipeline_stags' stages. This currently could only be used when
|
'pipeline_stags' stages. This currently could only be used when
|
||||||
parall mode semi_auto_parallel is enabled. Default: 0
|
parall mode semi_auto_parallel is enabled. Default: 0
|
||||||
|
communi_parallel_mode (str): There are tree kinds of communication parallel modes, "all_group_parallel",
|
||||||
|
"same_server_group_parallel" and "no_group_parallel". Default: "all_group_parallel".
|
||||||
|
|
||||||
|
- all_group_parallel: All communication groups are in parallel.
|
||||||
|
|
||||||
|
- same_server_group_parallel: Only the communication groups within the same server are parallel.
|
||||||
|
|
||||||
|
- no_group_parallel: All communication groups are not parallel.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If input key is not attribute in auto parallel context.
|
ValueError: If input key is not attribute in auto parallel context.
|
||||||
|
|
|
@ -21,19 +21,22 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||||
|
|
||||||
def test_set_auto_parallel_context():
|
def test_set_auto_parallel_context():
|
||||||
context.set_auto_parallel_context(device_num=4, global_rank=3, gradients_mean=True, gradient_fp32_sync=False,
|
context.set_auto_parallel_context(device_num=4, global_rank=3, gradients_mean=True, gradient_fp32_sync=False,
|
||||||
parallel_mode="auto_parallel", parameter_broadcast=False)
|
parallel_mode="auto_parallel", parameter_broadcast=False,
|
||||||
|
communi_parallel_mode="same_server_group_parallel")
|
||||||
device_num = context.get_auto_parallel_context("device_num")
|
device_num = context.get_auto_parallel_context("device_num")
|
||||||
global_rank = context.get_auto_parallel_context("global_rank")
|
global_rank = context.get_auto_parallel_context("global_rank")
|
||||||
gradients_mean = context.get_auto_parallel_context("gradients_mean")
|
gradients_mean = context.get_auto_parallel_context("gradients_mean")
|
||||||
gradient_fp32_sync = context.get_auto_parallel_context("gradient_fp32_sync")
|
gradient_fp32_sync = context.get_auto_parallel_context("gradient_fp32_sync")
|
||||||
parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||||
parameter_broadcast = context.get_auto_parallel_context("parameter_broadcast")
|
parameter_broadcast = context.get_auto_parallel_context("parameter_broadcast")
|
||||||
|
communi_parallel_mode = context.get_auto_parallel_context("communi_parallel_mode")
|
||||||
assert device_num == 4
|
assert device_num == 4
|
||||||
assert global_rank == 3
|
assert global_rank == 3
|
||||||
assert gradients_mean
|
assert gradients_mean
|
||||||
assert not gradient_fp32_sync
|
assert not gradient_fp32_sync
|
||||||
assert parallel_mode == "auto_parallel"
|
assert parallel_mode == "auto_parallel"
|
||||||
assert not parameter_broadcast
|
assert not parameter_broadcast
|
||||||
|
assert communi_parallel_mode == "same_server_group_parallel"
|
||||||
|
|
||||||
auto_parallel_context().set_device_num(4)
|
auto_parallel_context().set_device_num(4)
|
||||||
device_num = auto_parallel_context().get_device_num()
|
device_num = auto_parallel_context().get_device_num()
|
||||||
|
@ -77,6 +80,9 @@ def test_set_auto_parallel_context():
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
set_algo_parameters(tensor_slice_align_size=1025)
|
set_algo_parameters(tensor_slice_align_size=1025)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
context.set_auto_parallel_context(communi_parallel_mode="wrong_mode")
|
||||||
|
|
||||||
context.set_auto_parallel_context(enable_parallel_optimizer=True)
|
context.set_auto_parallel_context(enable_parallel_optimizer=True)
|
||||||
assert context.get_auto_parallel_context("enable_parallel_optimizer")
|
assert context.get_auto_parallel_context("enable_parallel_optimizer")
|
||||||
assert not auto_parallel_context().get_all_reduce_fusion_split_indices()
|
assert not auto_parallel_context().get_all_reduce_fusion_split_indices()
|
||||||
|
@ -98,6 +104,7 @@ def test_reset_auto_parallel_context():
|
||||||
device_num_is_set = auto_parallel_context().get_device_num_is_set()
|
device_num_is_set = auto_parallel_context().get_device_num_is_set()
|
||||||
parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set()
|
parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set()
|
||||||
stage = auto_parallel_context().get_pipeline_stages()
|
stage = auto_parallel_context().get_pipeline_stages()
|
||||||
|
communi_parallel_mode = context.get_auto_parallel_context("communi_parallel_mode")
|
||||||
|
|
||||||
assert device_num == 1
|
assert device_num == 1
|
||||||
assert global_rank == 0
|
assert global_rank == 0
|
||||||
|
@ -108,3 +115,4 @@ def test_reset_auto_parallel_context():
|
||||||
assert not device_num_is_set
|
assert not device_num_is_set
|
||||||
assert not parameter_broadcast_is_set
|
assert not parameter_broadcast_is_set
|
||||||
assert stage == 1
|
assert stage == 1
|
||||||
|
assert communi_parallel_mode == "all_group_parallel"
|
||||||
|
|
Loading…
Reference in New Issue