In a previous PR (https://gitee.com/mindspore/mindspore/pulls/26807/), we replaced 'auto_parallel_search_mode' with 'search_mode'.

However, to stay backward compatible, 'auto_parallel_search_mode' should remain available. This PR restores the 'auto_parallel_search_mode' interface and emits a warning whenever the old interface is used.

This PR also fixes some code-style issues.
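
For reference, this is how the two spellings behave after this change (a minimal usage sketch; the parameter values are only examples, and running it requires a MindSpore build containing this patch):

    from mindspore import context

    # Preferred spelling introduced by the earlier PR.
    context.set_auto_parallel_context(parallel_mode="auto_parallel", search_mode="dynamic_programming")

    # Old spelling restored by this PR: it still takes effect, but logs a warning that
    # 'auto_parallel_search_mode' is replaced by 'search_mode' and will be deleted in a
    # future MindSpore version.
    context.set_auto_parallel_context(auto_parallel_search_mode="dynamic_programming")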
Author: Xiaoda Zhang
Date:   2021-11-29 15:17:51 +08:00
Parent: 60bfef499f
Commit: 04db51a528
7 changed files with 97 additions and 15 deletions

Changed file 1 of 7:

@@ -74,6 +74,7 @@ void ParallelContext::Reset() {
   optimizer_weight_shard_aggregated_save_ = false;
   enable_all2all_ = false;
   grad_accumulation_shard_ = true;
+  sharding_propagation_ = false;
   dataset_strategy_.clear();
 }

@@ -276,5 +277,7 @@ void ParallelContext::ParallelParameterContextCkptShape(const FuncGraphPtr &func
 }
 void ParallelContext::set_enable_all2all(const bool enable) { enable_all2all_ = enable; }
+void ParallelContext::set_sharding_propagation(const bool stra_pto) { sharding_propagation_ = stra_pto; }
 }  // namespace parallel
 }  // namespace mindspore

Changed file 2 of 7:

@@ -148,6 +148,8 @@ class ParallelContext {
                                              const AbstractBasePtr &ptr);
   void ParallelParameterContextCkptShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
                                          const AbstractBasePtr &ptr);
+  void set_sharding_propagation(const bool);
+  bool sharding_propagation() const { return sharding_propagation_; }

  private:
   ParallelContext();

@@ -183,6 +185,7 @@ class ParallelContext {
   std::vector<std::vector<int64_t>> dataset_strategy_;
   bool dataset_repeat_dim_right_ = false;
   bool hccl_test_available_ = false;
+  bool sharding_propagation_;
 };
 }  // namespace parallel

Changed file 3 of 7:

@@ -105,14 +105,15 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
   // search parallelization strategy
   if ((strategy_search_mode == DYNAMIC_PROGRAMMING) || (strategy_search_mode == SHARDING_PROPAGATION)) {
     if (ParallelStrategySearch(all_nodes, root) != SUCCESS) {
-      MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using DP searching mode";
+      MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using " << strategy_search_mode
+                        << " searching mode";
     }
   } else if (strategy_search_mode == RECURSIVE_PROGRAMMING) {
     if (ParallelStrategyRecSearch(all_nodes, root) != SUCCESS) {
       MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using RP searching mode";
     }
   } else {
-    MS_LOG(EXCEPTION) << "Auto-parallel strategy searching mode unexpected";
+    MS_LOG(EXCEPTION) << "Auto-parallel strategy searching mode unexpected: " << strategy_search_mode;
   }
   (void)gettimeofday(&end_time, nullptr);
@@ -288,6 +289,14 @@ void SetStrategyToOperator(const OperatorInfoPtr &operator_info, const Primitive
   (void)configured_stra_ops_.emplace(operator_info, strategyPtr);
 }
+
+void ApplyApproximationForNode(const OperatorInfoPtr &operator_info) {
+  auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
+  if (approximation) {
+    operator_info->ApproximateStrategies();
+    MS_LOG(INFO) << "Approximated StrategyCost for: " << operator_info->name();
+  }
+}

 OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, bool is_last_nodes,
                                       StrategyMap *stra_map) {
   MS_EXCEPTION_IF_NULL(prim);
@@ -369,8 +378,10 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
     return nullptr;
   }
-  if ((ParallelContext::GetInstance()->strategy_search_mode() == SHARDING_PROPAGATION) &&
-      (operator_info->name().find(VIRTUAL_DATA_SET_INFO) != std::string::npos)) {
+  bool use_sp_and_dataset = ((ParallelContext::GetInstance()->strategy_search_mode() == SHARDING_PROPAGATION) ||
+                             (ParallelContext::GetInstance()->sharding_propagation())) &&
+                            (operator_info->name().find(VIRTUAL_DATA_SET_INFO) != std::string::npos);
+  if (use_sp_and_dataset) {
     const auto &swc_vec = operator_info->GetStrategyCost();
     if (swc_vec.empty()) {
       MS_LOG(EXCEPTION) << "No available strategy for: " << operator_info->name();

@@ -379,11 +390,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
     (void)configured_stra_ops_.emplace(operator_info, swc_vec[0]->strategy_ptr);
   }
   // If 'approximation' is enabled, the 'strategy_cost' of each operator is approximated
-  auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
-  if (approximation) {
-    operator_info->ApproximateStrategies();
-    MS_LOG(INFO) << "Approximated StrategyCost for: " << operator_info->name();
-  }
+  ApplyApproximationForNode(operator_info);
   return operator_info;
 }
@@ -639,7 +646,9 @@ void CreateEdgeBetweenTwoOps(const OperatorInfoPtr &prev_op_info, const Operator
   node_op_info->AddPrevEdge(edge_ptr);
   prev_op_info->AddSuccEdge(edge_ptr);
   entire_costgraph->AddEdge(prev_op_info, node_op_info, edge_ptr);
-  if ((ParallelContext::GetInstance()->strategy_search_mode() == SHARDING_PROPAGATION) && (prev_prim->name() == CAST) &&
+  bool use_sp = (ParallelContext::GetInstance()->strategy_search_mode() == SHARDING_PROPAGATION) ||
+                (ParallelContext::GetInstance()->sharding_propagation());
+  if (use_sp && (prev_prim->name() == CAST) &&
       (configured_stra_ops_.find(node_op_info) != configured_stra_ops_.end())) {
     const auto next_op_stra = configured_stra_ops_[node_op_info];
     const auto cast_stra = edge_ptr->GetPrevOpStrategyByNextOpStrategyWithMiniComm(next_op_stra);
@@ -990,7 +999,9 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
   }

   // Step 4: run the strategy searching algorithm
-  if ((ParallelContext::GetInstance()->strategy_search_mode() == SHARDING_PROPAGATION)) {
+  bool use_sp = (ParallelContext::GetInstance()->strategy_search_mode() == SHARDING_PROPAGATION) ||
+                (ParallelContext::GetInstance()->sharding_propagation());
+  if (use_sp) {
     entire_costgraph->StrategyPropagate(configured_stra_ops_);
     configured_stra_ops_.clear();
   } else if (GetStrategy(entire_costgraph) != SUCCESS) {
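
With the 'use_sp' conditions above, strategy propagation is now triggered either by the public search mode or by the separate boolean flag restored in this PR. A minimal sketch of the two user-facing paths (the import path of auto_parallel_context() is assumed here; it refers to the internal module edited further below):

    from mindspore import context
    from mindspore.parallel._auto_parallel_context import auto_parallel_context

    # Preferred: select the propagation algorithm through the public search mode.
    context.set_auto_parallel_context(parallel_mode="auto_parallel", search_mode="sharding_propagation")

    # Deprecated: the boolean flag checked by the 'use_sp' conditions; its Python setter
    # (added below) logs a deprecation warning before forwarding the value to C++.
    auto_parallel_context().set_sharding_propagation(True)
    assert auto_parallel_context().get_sharding_propagation() is True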

Changed file 4 of 7:

@@ -211,6 +211,9 @@ PYBIND11_MODULE(_c_expression, m) {
          "Get whether to integrated save weight shard when enable parallel optimizer.")
     .def("set_enable_alltoall", &ParallelContext::set_enable_all2all, "Set the enabling AllToAll value.")
     .def("get_enable_alltoall", &ParallelContext::enable_all2all, "Get the enabling AllToAll value.")
+    .def("set_sharding_propagation", &ParallelContext::set_sharding_propagation,
+         "Set sharding strategy propagation value.")
+    .def("get_sharding_propagation", &ParallelContext::sharding_propagation, "Get sharding strategy propagation value.")
     .def("reset", &ParallelContext::Reset, "Reset auto parallel context.");

   (void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")

Changed file 5 of 7:

@@ -370,7 +370,7 @@ def _context():
 @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, parallel_mode=str,
-                 search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
+                 auto_parallel_search_mode=str, search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
                  strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
                  all_reduce_fusion_config=list, pipeline_stages=int, grad_accumulation_step=int,
                  parallel_optimizer_config=dict)

@@ -401,6 +401,7 @@ def set_auto_parallel_context(**kwargs):
              enable_parallel_optimizer    dataset_strategy
              parallel_optimizer_config    pipeline_stages
                         \                 grad_accumulation_step
+                        \                 auto_parallel_search_mode
              ===========================  ===========================

     Args:
@@ -431,6 +432,8 @@ def set_auto_parallel_context(**kwargs):
                      - dynamic_programming: Dynamic programming search mode.
                      - sharding_propagation: Propagate shardings from configured ops to non-configured ops.
+        auto_parallel_search_mode (str): This is the old version of 'search_mode'. The attribute is kept for
+                     compatibility with older scripts and will be deleted in a future MindSpore version.
         parameter_broadcast (bool): Whether to broadcast parameters before training. Before training, in order to have
                      the same network initialization parameter values for all devices, broadcast the parameters
                      on device 0 to other devices. Parameter broadcasting in different parallel modes is different,
@@ -485,6 +488,7 @@ def set_auto_parallel_context(**kwargs):
         >>> context.set_auto_parallel_context(gradient_fp32_sync=False)
         >>> context.set_auto_parallel_context(parallel_mode="auto_parallel")
         >>> context.set_auto_parallel_context(search_mode="dynamic_programming")
+        >>> context.set_auto_parallel_context(auto_parallel_search_mode="dynamic_programming")
         >>> context.set_auto_parallel_context(parameter_broadcast=False)
         >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
         >>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")

@@ -529,6 +533,7 @@ def reset_auto_parallel_context():
     - gradient_fp32_sync: True.
     - parallel_mode: 'stand_alone'.
     - search_mode: 'dynamic_programming'.
+    - auto_parallel_search_mode: 'dynamic_programming'.
     - parameter_broadcast: False.
     - strategy_ckpt_load_file: ''.
     - strategy_ckpt_save_file: ''.

Changed file 6 of 7:

@@ -233,6 +233,55 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_strategy_search_mode()

+    def set_auto_parallel_search_mode(self, search_mode):
+        """
+        Set the search mode of strategy searching. This is the old version of 'search_mode' and will be
+        deleted in a future MindSpore version.
+
+        Args:
+            search_mode (str): The search mode of strategy.
+        """
+        logger.warning("The attribute 'auto_parallel_search_mode' has been replaced by 'search_mode'. "
+                       "The attribute 'auto_parallel_search_mode' will be deleted in a future MindSpore version.")
+        self.check_context_handle()
+        ret = self._context_handle.set_strategy_search_mode(search_mode)
+        if ret is False:
+            raise ValueError("The context configuration parameter 'auto_parallel_search_mode' only supports "
+                             "'recursive_programming', 'dynamic_programming' and 'sharding_propagation', "
+                             "but got the value: {}.".format(search_mode))
+
+    def get_auto_parallel_search_mode(self):
+        """
+        Get the search mode of strategy. This is the old version of 'search_mode' and will be deleted in a
+        future MindSpore version.
+        """
+        logger.warning("The attribute 'auto_parallel_search_mode' has been replaced by 'search_mode'. "
+                       "The attribute 'auto_parallel_search_mode' will be deleted in a future MindSpore version.")
+        self.check_context_handle()
+        return self._context_handle.get_strategy_search_mode()
+
+    def set_sharding_propagation(self, sharding_propagation):
+        """
+        Set whether to use sharding strategy propagation in AUTO_PARALLEL mode. If True, the strategy-configured
+        operators propagate their strategies to other operators with minimum redistribution cost; otherwise, the
+        algorithm searches the desired strategies. Default: False.
+        This attribute is replaced by context.set_auto_parallel_context(search_mode="sharding_propagation").
+
+        Args:
+            sharding_propagation (bool): Enable/disable strategy propagation.
+        """
+        logger.warning("This attribute is replaced by "
+                       "context.set_auto_parallel_context(search_mode='sharding_propagation'), "
+                       "and it will be deleted in a future MindSpore version.")
+        self.check_context_handle()
+        if not isinstance(sharding_propagation, bool):
+            raise TypeError("The type of parameter 'sharding_propagation' must be bool, "
+                            "but got the type: {}.".format(type(sharding_propagation)))
+        self._context_handle.set_sharding_propagation(sharding_propagation)
+
+    def get_sharding_propagation(self):
+        """Get the value of sharding strategy propagation."""
+        self.check_context_handle()
+        return self._context_handle.get_sharding_propagation()
+
     def set_parameter_broadcast(self, parameter_broadcast):
         """
         Set parameter broadcast.
@@ -677,6 +726,7 @@ _set_auto_parallel_context_func_map = {
     "pipeline_stages": auto_parallel_context().set_pipeline_stages,
     "parallel_mode": auto_parallel_context().set_parallel_mode,
     "search_mode": auto_parallel_context().set_strategy_search_mode,
+    "auto_parallel_search_mode": auto_parallel_context().set_auto_parallel_search_mode,
     "parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
     "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
     "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,

@@ -690,6 +740,7 @@ _set_auto_parallel_context_func_map = {
     "communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode,
     "optimizer_weight_shard_size": auto_parallel_context().set_optimizer_weight_shard_size,
     "optimizer_weight_shard_aggregated_save": auto_parallel_context().set_optimizer_weight_shard_aggregated_save,
+    "sharding_propagation": auto_parallel_context().set_sharding_propagation,
     "enable_alltoall": auto_parallel_context().set_enable_alltoall}

@@ -702,6 +753,7 @@ _get_auto_parallel_context_func_map = {
     "pipeline_stages": auto_parallel_context().get_pipeline_stages,
     "parallel_mode": auto_parallel_context().get_parallel_mode,
     "search_mode": auto_parallel_context().get_strategy_search_mode,
+    "auto_parallel_search_mode": auto_parallel_context().get_auto_parallel_search_mode,
     "parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
     "strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
     "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file,

@@ -713,15 +765,16 @@ _get_auto_parallel_context_func_map = {
     "communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode,
     "optimizer_weight_shard_size": auto_parallel_context().get_optimizer_weight_shard_size,
     "optimizer_weight_shard_aggregated_save": auto_parallel_context().get_optimizer_weight_shard_aggregated_save,
+    "sharding_propagation": auto_parallel_context().get_sharding_propagation,
     "enable_alltoall": auto_parallel_context().get_enable_alltoall}


 @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
-                 loss_repeated_mean=bool, parallel_mode=str, search_mode=str,
+                 loss_repeated_mean=bool, parallel_mode=str, search_mode=str, auto_parallel_search_mode=str,
                  parameter_broadcast=bool, strategy_ckpt_load_file=str,
                  strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
                  grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
-                 communi_parallel_mode=str, optimizer_weight_shard_size=int,
+                 communi_parallel_mode=str, optimizer_weight_shard_size=int, sharding_propagation=bool,
                  optimizer_weight_shard_aggregated_save=bool, enable_alltoall=bool)
 def _set_auto_parallel_context(**kwargs):
@@ -760,6 +813,8 @@ def _set_auto_parallel_context(**kwargs):
                      - dynamic_programming: Dynamic programming search mode.
                      - sharding_propagation: Propagate shardings from configured ops to non-configured ops.
+        auto_parallel_search_mode (str): This is the old version of 'search_mode'. The attribute is kept for
+                     compatibility with older scripts and will be deleted in a future MindSpore version.
         parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
                        "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
                        broadcast. Default: False.
@@ -837,6 +892,8 @@ def _reset_auto_parallel_context():
     - strategy_ckpt_save_file: ""
     - enable_parallel_optimizer: False
     - search_mode: dynamic_programming
+    - auto_parallel_search_mode: dynamic_programming
+    - sharding_propagation: False
     - pipeline_stages: 0
     - gradient_accumulation_shard: True
     """

Changed file 7 of 7:

@@ -77,7 +77,7 @@ def auto_parallel_activation3():
 def test_auto_parallel_activation4():
     context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0,
-                                      search_mode="sharding_propagation")
+                                      auto_parallel_search_mode="sharding_propagation")
     strategy1 = ((4, 4), (4, 4))
     strategy2 = None
     strategy3 = ((8, 2),)