In a previous PR (https://gitee.com/mindspore/mindspore/pulls/26807/), we replaced 'auto_parallel_search_mode' by 'search_mode' directly.

However, to be forward compatitable, it is suitable to keep 'auto_parallel_search_mode' available. This PR recovers the 'auto_parallel_search_mode' interface and adds a warning when using this old interface.

This PR also deals with other codestyle things.
This commit is contained in:
Xiaoda Zhang 2021-11-29 15:17:51 +08:00
parent 60bfef499f
commit 04db51a528
7 changed files with 97 additions and 15 deletions

View File

@ -74,6 +74,7 @@ void ParallelContext::Reset() {
optimizer_weight_shard_aggregated_save_ = false;
enable_all2all_ = false;
grad_accumulation_shard_ = true;
sharding_propagation_ = false;
dataset_strategy_.clear();
}
@ -276,5 +277,7 @@ void ParallelContext::ParallelParameterContextCkptShape(const FuncGraphPtr &func
}
void ParallelContext::set_enable_all2all(const bool enable) { enable_all2all_ = enable; }
void ParallelContext::set_sharding_propagation(const bool stra_pto) { sharding_propagation_ = stra_pto; }
} // namespace parallel
} // namespace mindspore

View File

@ -148,6 +148,8 @@ class ParallelContext {
const AbstractBasePtr &ptr);
void ParallelParameterContextCkptShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
const AbstractBasePtr &ptr);
void set_sharding_propagation(const bool);
bool sharding_propagation() const { return sharding_propagation_; }
private:
ParallelContext();
@ -183,6 +185,7 @@ class ParallelContext {
std::vector<std::vector<int64_t>> dataset_strategy_;
bool dataset_repeat_dim_right_ = false;
bool hccl_test_available_ = false;
bool sharding_propagation_;
};
} // namespace parallel

View File

@ -105,14 +105,15 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
// search parallelization strategy
if ((strategy_search_mode == DYNAMIC_PROGRAMMING) || (strategy_search_mode == SHARDING_PROPAGATION)) {
if (ParallelStrategySearch(all_nodes, root) != SUCCESS) {
MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using DP searching mode";
MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using " << strategy_search_mode
<< " searching mode";
}
} else if (strategy_search_mode == RECURSIVE_PROGRAMMING) {
if (ParallelStrategyRecSearch(all_nodes, root) != SUCCESS) {
MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using RP searching mode";
}
} else {
MS_LOG(EXCEPTION) << "Auto-parallel strategy searching mode unexpected";
MS_LOG(EXCEPTION) << "Auto-parallel strategy searching mode unexpected: " << strategy_search_mode;
}
(void)gettimeofday(&end_time, nullptr);
@ -288,6 +289,14 @@ void SetStrategyToOperator(const OperatorInfoPtr &operator_info, const Primitive
(void)configured_stra_ops_.emplace(operator_info, strategyPtr);
}
void ApplyApproximationForNode(const OperatorInfoPtr &operator_info) {
auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
if (approximation) {
operator_info->ApproximateStrategies();
MS_LOG(INFO) << "Approximated StrategyCost for: " << operator_info->name();
}
}
OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, bool is_last_nodes,
StrategyMap *stra_map) {
MS_EXCEPTION_IF_NULL(prim);
@ -369,8 +378,10 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
return nullptr;
}
if ((ParallelContext::GetInstance()->strategy_search_mode() == SHARDING_PROPAGATION) &&
(operator_info->name().find(VIRTUAL_DATA_SET_INFO) != std::string::npos)) {
bool use_sp_and_dataset = ((ParallelContext::GetInstance()->strategy_search_mode() == SHARDING_PROPAGATION) ||
(ParallelContext::GetInstance()->sharding_propagation())) &&
(operator_info->name().find(VIRTUAL_DATA_SET_INFO) != std::string::npos);
if (use_sp_and_dataset) {
const auto &swc_vec = operator_info->GetStrategyCost();
if (swc_vec.empty()) {
MS_LOG(EXCEPTION) << "No available strategy for: " << operator_info->name();
@ -379,11 +390,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
(void)configured_stra_ops_.emplace(operator_info, swc_vec[0]->strategy_ptr);
}
// If 'approximation' is enabled, the 'strategy_cost' of each operator is approximated
auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
if (approximation) {
operator_info->ApproximateStrategies();
MS_LOG(INFO) << "Approximated StrategyCost for: " << operator_info->name();
}
ApplyApproximationForNode(operator_info);
return operator_info;
}
@ -639,7 +646,9 @@ void CreateEdgeBetweenTwoOps(const OperatorInfoPtr &prev_op_info, const Operator
node_op_info->AddPrevEdge(edge_ptr);
prev_op_info->AddSuccEdge(edge_ptr);
entire_costgraph->AddEdge(prev_op_info, node_op_info, edge_ptr);
if ((ParallelContext::GetInstance()->strategy_search_mode() == SHARDING_PROPAGATION) && (prev_prim->name() == CAST) &&
bool use_sp = (ParallelContext::GetInstance()->strategy_search_mode() == SHARDING_PROPAGATION) ||
(ParallelContext::GetInstance()->sharding_propagation());
if (use_sp && (prev_prim->name() == CAST) &&
(configured_stra_ops_.find(node_op_info) != configured_stra_ops_.end())) {
const auto next_op_stra = configured_stra_ops_[node_op_info];
const auto cast_stra = edge_ptr->GetPrevOpStrategyByNextOpStrategyWithMiniComm(next_op_stra);
@ -990,7 +999,9 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
}
// Step 4: run the strategy searching algorithm
if ((ParallelContext::GetInstance()->strategy_search_mode() == SHARDING_PROPAGATION)) {
bool use_sp = (ParallelContext::GetInstance()->strategy_search_mode() == SHARDING_PROPAGATION) ||
(ParallelContext::GetInstance()->sharding_propagation());
if (use_sp) {
entire_costgraph->StrategyPropagate(configured_stra_ops_);
configured_stra_ops_.clear();
} else if (GetStrategy(entire_costgraph) != SUCCESS) {

View File

@ -211,6 +211,9 @@ PYBIND11_MODULE(_c_expression, m) {
"Get whether to integrated save weight shard when enable parallel optimizer.")
.def("set_enable_alltoall", &ParallelContext::set_enable_all2all, "Set the enabling AllToAll value.")
.def("get_enable_alltoall", &ParallelContext::enable_all2all, "Get the enabling AllToAll value.")
.def("set_sharding_propagation", &ParallelContext::set_sharding_propagation,
"Set sharding strategy propagation value.")
.def("get_sharding_propagation", &ParallelContext::sharding_propagation, "Get sharding strategy propagation value.")
.def("reset", &ParallelContext::Reset, "Reset auto parallel context.");
(void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")

View File

@ -370,7 +370,7 @@ def _context():
@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, parallel_mode=str,
search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
auto_parallel_search_mode=str, search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
all_reduce_fusion_config=list, pipeline_stages=int, grad_accumulation_step=int,
parallel_optimizer_config=dict)
@ -401,6 +401,7 @@ def set_auto_parallel_context(**kwargs):
enable_parallel_optimizer dataset_strategy
parallel_optimizer_config pipeline_stages
\ grad_accumulation_step
\ auto_parallel_search_mode
=========================== ===========================
Args:
@ -431,6 +432,8 @@ def set_auto_parallel_context(**kwargs):
- dynamic_programming: Dynamic programming search mode.
- sharding_propagation: Propagate shardings from configured ops to non-configured ops.
auto_parallel_search_mode (str): This is the old version of 'search_mode'. Here, remaining this attribute is
for forward compatibility, and this attribute will be deleted in a future MindSpore version.
parameter_broadcast (bool): Whether to broadcast parameters before training. Before training, in order to have
the same network initialization parameter values for all devices, broadcast the parameters
on device 0 to other devices. Parameter broadcasting in different parallel modes is different,
@ -485,6 +488,7 @@ def set_auto_parallel_context(**kwargs):
>>> context.set_auto_parallel_context(gradient_fp32_sync=False)
>>> context.set_auto_parallel_context(parallel_mode="auto_parallel")
>>> context.set_auto_parallel_context(search_mode="dynamic_programming")
>>> context.set_auto_parallel_context(auto_parallel_search_mode="dynamic_programming")
>>> context.set_auto_parallel_context(parameter_broadcast=False)
>>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
>>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")
@ -529,6 +533,7 @@ def reset_auto_parallel_context():
- gradient_fp32_sync: True.
- parallel_mode: 'stand_alone'.
- search_mode: 'dynamic_programming'.
- auto_parallel_search_mode: 'dynamic_programming'.
- parameter_broadcast: False.
- strategy_ckpt_load_file: ''.
- strategy_ckpt_save_file: ''.

View File

@ -233,6 +233,55 @@ class _AutoParallelContext:
self.check_context_handle()
return self._context_handle.get_strategy_search_mode()
def set_auto_parallel_search_mode(self, search_mode):
"""
Set search mode of strategy searching. This is the old version of 'search_mode', and will be deleted in a future
MindSpore version.
Args:
search_mode (str): The search mode of strategy.
"""
logger.warning("The attribute 'auto_parallel_search_mode' is currently replaced by 'search_mode'. "
"The attribute 'auto_parallel_search_mode' will be deleted in a future MindSpore version.")
self.check_context_handle()
ret = self._context_handle.set_strategy_search_mode(search_mode)
if ret is False:
raise ValueError("The context configuration parameter 'search_mode' only support "
"'recursive_programming' and 'dynamic_programming', but got the value : {}."
.format(search_mode))
def get_auto_parallel_search_mode(self):
"""Get search mode of strategy. This is the old version of 'search_mode', and will be deleted in a future
MindSpore version.
"""
logger.warning("The attribute 'auto_parallel_search_mode' is currently replaced by 'search_mode'. "
"The attribute 'auto_parallel_search_mode' will be deleted in a future MindSpore version.")
self.check_context_handle()
return self._context_handle.get_strategy_search_mode()
def set_sharding_propagation(self, sharding_propagation):
"""
Set the value of sharding strategy propagation in AUTO_PARALLEL mode. If True, the strategy-configured operators
will propagate the strategies to other operators with minimum redistribution cost; otherwise, the algorithm
will search the desired strategies. Default: False.
This attribute is replaced by context.set_auto_parallel(search_mode="sharding_propagation").
Args:
sharding_propagation (bool): Enable/disable strategy propagation.
"""
logger.warning("This attribute is replaced by context.set_auto_parallel(search_mode='sharding_propagation'), "
"and this attribute will be deleted in a future MindSpore version.")
self.check_context_handle()
if not isinstance(sharding_propagation, bool):
raise TypeError("The type of parameter 'sharding_propagation' must be bool, "
"but got the type : {}.".format(type(sharding_propagation)))
self._context_handle.set_sharding_propagation(sharding_propagation)
def get_sharding_propagation(self):
"""Get the value of sharding strategy propagation."""
self.check_context_handle()
return self._context_handle.get_sharding_propagation()
def set_parameter_broadcast(self, parameter_broadcast):
"""
Set parameter broadcast.
@ -677,6 +726,7 @@ _set_auto_parallel_context_func_map = {
"pipeline_stages": auto_parallel_context().set_pipeline_stages,
"parallel_mode": auto_parallel_context().set_parallel_mode,
"search_mode": auto_parallel_context().set_strategy_search_mode,
"auto_parallel_search_mode": auto_parallel_context().set_auto_parallel_search_mode,
"parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
"strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
"strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
@ -690,6 +740,7 @@ _set_auto_parallel_context_func_map = {
"communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode,
"optimizer_weight_shard_size": auto_parallel_context().set_optimizer_weight_shard_size,
"optimizer_weight_shard_aggregated_save": auto_parallel_context().set_optimizer_weight_shard_aggregated_save,
"sharding_propagation": auto_parallel_context().set_sharding_propagation,
"enable_alltoall": auto_parallel_context().set_enable_alltoall}
@ -702,6 +753,7 @@ _get_auto_parallel_context_func_map = {
"pipeline_stages": auto_parallel_context().get_pipeline_stages,
"parallel_mode": auto_parallel_context().get_parallel_mode,
"search_mode": auto_parallel_context().get_strategy_search_mode,
"auto_parallel_search_mode": auto_parallel_context().get_auto_parallel_search_mode,
"parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
"strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
"strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file,
@ -713,15 +765,16 @@ _get_auto_parallel_context_func_map = {
"communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode,
"optimizer_weight_shard_size": auto_parallel_context().get_optimizer_weight_shard_size,
"optimizer_weight_shard_aggregated_save": auto_parallel_context().get_optimizer_weight_shard_aggregated_save,
"sharding_propagation": auto_parallel_context().get_sharding_propagation,
"enable_alltoall": auto_parallel_context().get_enable_alltoall}
@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
loss_repeated_mean=bool, parallel_mode=str, search_mode=str,
loss_repeated_mean=bool, parallel_mode=str, search_mode=str, auto_parallel_search_mode=str,
parameter_broadcast=bool, strategy_ckpt_load_file=str,
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
communi_parallel_mode=str, optimizer_weight_shard_size=int,
communi_parallel_mode=str, optimizer_weight_shard_size=int, sharding_propagation=bool,
optimizer_weight_shard_aggregated_save=bool, enable_alltoall=bool)
def _set_auto_parallel_context(**kwargs):
@ -760,6 +813,8 @@ def _set_auto_parallel_context(**kwargs):
- dynamic_programming: Dynamic programming search mode.
- sharding_propagation: Propagate shardings from configured ops to non-configured ops.
auto_parallel_search_mode (str): This is the old version of 'search_mode'. Here, remaining this attribute is
for forward compatibility, and this attribute will be deleted in a future MindSpore version.
parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
"stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
broadcast. Default: False.
@ -837,6 +892,8 @@ def _reset_auto_parallel_context():
- strategy_ckpt_save_file: ""
- enable_parallel_optimizer: False
- search_mode: dynamic_programming
- auto_parallel_search_mode: dynamic_programming
- sharding_propagation: False
- pipeline_stages: 0
- gradient_accumulation_shard: True
"""

View File

@ -77,7 +77,7 @@ def auto_parallel_activation3():
def test_auto_parallel_activation4():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0,
search_mode="sharding_propagation")
auto_parallel_search_mode="sharding_propagation")
strategy1 = ((4, 4), (4, 4))
strategy2 = None
strategy3 = ((8, 2),)