In a previous PR (https://gitee.com/mindspore/mindspore/pulls/26807/), we replaced 'auto_parallel_search_mode' by 'search_mode' directly.
However, to be forward compatitable, it is suitable to keep 'auto_parallel_search_mode' available. This PR recovers the 'auto_parallel_search_mode' interface and adds a warning when using this old interface. This PR also deals with other codestyle things.
This commit is contained in:
parent
60bfef499f
commit
04db51a528
|
@ -74,6 +74,7 @@ void ParallelContext::Reset() {
|
|||
optimizer_weight_shard_aggregated_save_ = false;
|
||||
enable_all2all_ = false;
|
||||
grad_accumulation_shard_ = true;
|
||||
sharding_propagation_ = false;
|
||||
dataset_strategy_.clear();
|
||||
}
|
||||
|
||||
|
@ -276,5 +277,7 @@ void ParallelContext::ParallelParameterContextCkptShape(const FuncGraphPtr &func
|
|||
}
|
||||
|
||||
void ParallelContext::set_enable_all2all(const bool enable) { enable_all2all_ = enable; }
|
||||
|
||||
void ParallelContext::set_sharding_propagation(const bool stra_pto) { sharding_propagation_ = stra_pto; }
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -148,6 +148,8 @@ class ParallelContext {
|
|||
const AbstractBasePtr &ptr);
|
||||
void ParallelParameterContextCkptShape(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node,
|
||||
const AbstractBasePtr &ptr);
|
||||
void set_sharding_propagation(const bool);
|
||||
bool sharding_propagation() const { return sharding_propagation_; }
|
||||
|
||||
private:
|
||||
ParallelContext();
|
||||
|
@ -183,6 +185,7 @@ class ParallelContext {
|
|||
std::vector<std::vector<int64_t>> dataset_strategy_;
|
||||
bool dataset_repeat_dim_right_ = false;
|
||||
bool hccl_test_available_ = false;
|
||||
bool sharding_propagation_;
|
||||
};
|
||||
|
||||
} // namespace parallel
|
||||
|
|
|
@ -105,14 +105,15 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
|
|||
// search parallelization strategy
|
||||
if ((strategy_search_mode == DYNAMIC_PROGRAMMING) || (strategy_search_mode == SHARDING_PROPAGATION)) {
|
||||
if (ParallelStrategySearch(all_nodes, root) != SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using DP searching mode";
|
||||
MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using " << strategy_search_mode
|
||||
<< " searching mode";
|
||||
}
|
||||
} else if (strategy_search_mode == RECURSIVE_PROGRAMMING) {
|
||||
if (ParallelStrategyRecSearch(all_nodes, root) != SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using RP searching mode";
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Auto-parallel strategy searching mode unexpected";
|
||||
MS_LOG(EXCEPTION) << "Auto-parallel strategy searching mode unexpected: " << strategy_search_mode;
|
||||
}
|
||||
|
||||
(void)gettimeofday(&end_time, nullptr);
|
||||
|
@ -288,6 +289,14 @@ void SetStrategyToOperator(const OperatorInfoPtr &operator_info, const Primitive
|
|||
(void)configured_stra_ops_.emplace(operator_info, strategyPtr);
|
||||
}
|
||||
|
||||
void ApplyApproximationForNode(const OperatorInfoPtr &operator_info) {
|
||||
auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
|
||||
if (approximation) {
|
||||
operator_info->ApproximateStrategies();
|
||||
MS_LOG(INFO) << "Approximated StrategyCost for: " << operator_info->name();
|
||||
}
|
||||
}
|
||||
|
||||
OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, bool is_last_nodes,
|
||||
StrategyMap *stra_map) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
|
@ -369,8 +378,10 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
if ((ParallelContext::GetInstance()->strategy_search_mode() == SHARDING_PROPAGATION) &&
|
||||
(operator_info->name().find(VIRTUAL_DATA_SET_INFO) != std::string::npos)) {
|
||||
bool use_sp_and_dataset = ((ParallelContext::GetInstance()->strategy_search_mode() == SHARDING_PROPAGATION) ||
|
||||
(ParallelContext::GetInstance()->sharding_propagation())) &&
|
||||
(operator_info->name().find(VIRTUAL_DATA_SET_INFO) != std::string::npos);
|
||||
if (use_sp_and_dataset) {
|
||||
const auto &swc_vec = operator_info->GetStrategyCost();
|
||||
if (swc_vec.empty()) {
|
||||
MS_LOG(EXCEPTION) << "No available strategy for: " << operator_info->name();
|
||||
|
@ -379,11 +390,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
|
|||
(void)configured_stra_ops_.emplace(operator_info, swc_vec[0]->strategy_ptr);
|
||||
}
|
||||
// If 'approximation' is enabled, the 'strategy_cost' of each operator is approximated
|
||||
auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
|
||||
if (approximation) {
|
||||
operator_info->ApproximateStrategies();
|
||||
MS_LOG(INFO) << "Approximated StrategyCost for: " << operator_info->name();
|
||||
}
|
||||
ApplyApproximationForNode(operator_info);
|
||||
return operator_info;
|
||||
}
|
||||
|
||||
|
@ -639,7 +646,9 @@ void CreateEdgeBetweenTwoOps(const OperatorInfoPtr &prev_op_info, const Operator
|
|||
node_op_info->AddPrevEdge(edge_ptr);
|
||||
prev_op_info->AddSuccEdge(edge_ptr);
|
||||
entire_costgraph->AddEdge(prev_op_info, node_op_info, edge_ptr);
|
||||
if ((ParallelContext::GetInstance()->strategy_search_mode() == SHARDING_PROPAGATION) && (prev_prim->name() == CAST) &&
|
||||
bool use_sp = (ParallelContext::GetInstance()->strategy_search_mode() == SHARDING_PROPAGATION) ||
|
||||
(ParallelContext::GetInstance()->sharding_propagation());
|
||||
if (use_sp && (prev_prim->name() == CAST) &&
|
||||
(configured_stra_ops_.find(node_op_info) != configured_stra_ops_.end())) {
|
||||
const auto next_op_stra = configured_stra_ops_[node_op_info];
|
||||
const auto cast_stra = edge_ptr->GetPrevOpStrategyByNextOpStrategyWithMiniComm(next_op_stra);
|
||||
|
@ -990,7 +999,9 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
|
|||
}
|
||||
|
||||
// Step 4: run the strategy searching algorithm
|
||||
if ((ParallelContext::GetInstance()->strategy_search_mode() == SHARDING_PROPAGATION)) {
|
||||
bool use_sp = (ParallelContext::GetInstance()->strategy_search_mode() == SHARDING_PROPAGATION) ||
|
||||
(ParallelContext::GetInstance()->sharding_propagation());
|
||||
if (use_sp) {
|
||||
entire_costgraph->StrategyPropagate(configured_stra_ops_);
|
||||
configured_stra_ops_.clear();
|
||||
} else if (GetStrategy(entire_costgraph) != SUCCESS) {
|
||||
|
|
|
@ -211,6 +211,9 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
"Get whether to integrated save weight shard when enable parallel optimizer.")
|
||||
.def("set_enable_alltoall", &ParallelContext::set_enable_all2all, "Set the enabling AllToAll value.")
|
||||
.def("get_enable_alltoall", &ParallelContext::enable_all2all, "Get the enabling AllToAll value.")
|
||||
.def("set_sharding_propagation", &ParallelContext::set_sharding_propagation,
|
||||
"Set sharding strategy propagation value.")
|
||||
.def("get_sharding_propagation", &ParallelContext::sharding_propagation, "Get sharding strategy propagation value.")
|
||||
.def("reset", &ParallelContext::Reset, "Reset auto parallel context.");
|
||||
|
||||
(void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")
|
||||
|
|
|
@ -370,7 +370,7 @@ def _context():
|
|||
|
||||
|
||||
@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, parallel_mode=str,
|
||||
search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
|
||||
auto_parallel_search_mode=str, search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
|
||||
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
|
||||
all_reduce_fusion_config=list, pipeline_stages=int, grad_accumulation_step=int,
|
||||
parallel_optimizer_config=dict)
|
||||
|
@ -401,6 +401,7 @@ def set_auto_parallel_context(**kwargs):
|
|||
enable_parallel_optimizer dataset_strategy
|
||||
parallel_optimizer_config pipeline_stages
|
||||
\ grad_accumulation_step
|
||||
\ auto_parallel_search_mode
|
||||
=========================== ===========================
|
||||
|
||||
Args:
|
||||
|
@ -431,6 +432,8 @@ def set_auto_parallel_context(**kwargs):
|
|||
- dynamic_programming: Dynamic programming search mode.
|
||||
|
||||
- sharding_propagation: Propagate shardings from configured ops to non-configured ops.
|
||||
auto_parallel_search_mode (str): This is the old version of 'search_mode'. Here, remaining this attribute is
|
||||
for forward compatibility, and this attribute will be deleted in a future MindSpore version.
|
||||
parameter_broadcast (bool): Whether to broadcast parameters before training. Before training, in order to have
|
||||
the same network initialization parameter values for all devices, broadcast the parameters
|
||||
on device 0 to other devices. Parameter broadcasting in different parallel modes is different,
|
||||
|
@ -485,6 +488,7 @@ def set_auto_parallel_context(**kwargs):
|
|||
>>> context.set_auto_parallel_context(gradient_fp32_sync=False)
|
||||
>>> context.set_auto_parallel_context(parallel_mode="auto_parallel")
|
||||
>>> context.set_auto_parallel_context(search_mode="dynamic_programming")
|
||||
>>> context.set_auto_parallel_context(auto_parallel_search_mode="dynamic_programming")
|
||||
>>> context.set_auto_parallel_context(parameter_broadcast=False)
|
||||
>>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
|
||||
>>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")
|
||||
|
@ -529,6 +533,7 @@ def reset_auto_parallel_context():
|
|||
- gradient_fp32_sync: True.
|
||||
- parallel_mode: 'stand_alone'.
|
||||
- search_mode: 'dynamic_programming'.
|
||||
- auto_parallel_search_mode: 'dynamic_programming'.
|
||||
- parameter_broadcast: False.
|
||||
- strategy_ckpt_load_file: ''.
|
||||
- strategy_ckpt_save_file: ''.
|
||||
|
|
|
@ -233,6 +233,55 @@ class _AutoParallelContext:
|
|||
self.check_context_handle()
|
||||
return self._context_handle.get_strategy_search_mode()
|
||||
|
||||
def set_auto_parallel_search_mode(self, search_mode):
|
||||
"""
|
||||
Set search mode of strategy searching. This is the old version of 'search_mode', and will be deleted in a future
|
||||
MindSpore version.
|
||||
|
||||
Args:
|
||||
search_mode (str): The search mode of strategy.
|
||||
"""
|
||||
logger.warning("The attribute 'auto_parallel_search_mode' is currently replaced by 'search_mode'. "
|
||||
"The attribute 'auto_parallel_search_mode' will be deleted in a future MindSpore version.")
|
||||
self.check_context_handle()
|
||||
ret = self._context_handle.set_strategy_search_mode(search_mode)
|
||||
if ret is False:
|
||||
raise ValueError("The context configuration parameter 'search_mode' only support "
|
||||
"'recursive_programming' and 'dynamic_programming', but got the value : {}."
|
||||
.format(search_mode))
|
||||
|
||||
def get_auto_parallel_search_mode(self):
|
||||
"""Get search mode of strategy. This is the old version of 'search_mode', and will be deleted in a future
|
||||
MindSpore version.
|
||||
"""
|
||||
logger.warning("The attribute 'auto_parallel_search_mode' is currently replaced by 'search_mode'. "
|
||||
"The attribute 'auto_parallel_search_mode' will be deleted in a future MindSpore version.")
|
||||
self.check_context_handle()
|
||||
return self._context_handle.get_strategy_search_mode()
|
||||
|
||||
def set_sharding_propagation(self, sharding_propagation):
|
||||
"""
|
||||
Set the value of sharding strategy propagation in AUTO_PARALLEL mode. If True, the strategy-configured operators
|
||||
will propagate the strategies to other operators with minimum redistribution cost; otherwise, the algorithm
|
||||
will search the desired strategies. Default: False.
|
||||
This attribute is replaced by context.set_auto_parallel(search_mode="sharding_propagation").
|
||||
|
||||
Args:
|
||||
sharding_propagation (bool): Enable/disable strategy propagation.
|
||||
"""
|
||||
logger.warning("This attribute is replaced by context.set_auto_parallel(search_mode='sharding_propagation'), "
|
||||
"and this attribute will be deleted in a future MindSpore version.")
|
||||
self.check_context_handle()
|
||||
if not isinstance(sharding_propagation, bool):
|
||||
raise TypeError("The type of parameter 'sharding_propagation' must be bool, "
|
||||
"but got the type : {}.".format(type(sharding_propagation)))
|
||||
self._context_handle.set_sharding_propagation(sharding_propagation)
|
||||
|
||||
def get_sharding_propagation(self):
|
||||
"""Get the value of sharding strategy propagation."""
|
||||
self.check_context_handle()
|
||||
return self._context_handle.get_sharding_propagation()
|
||||
|
||||
def set_parameter_broadcast(self, parameter_broadcast):
|
||||
"""
|
||||
Set parameter broadcast.
|
||||
|
@ -677,6 +726,7 @@ _set_auto_parallel_context_func_map = {
|
|||
"pipeline_stages": auto_parallel_context().set_pipeline_stages,
|
||||
"parallel_mode": auto_parallel_context().set_parallel_mode,
|
||||
"search_mode": auto_parallel_context().set_strategy_search_mode,
|
||||
"auto_parallel_search_mode": auto_parallel_context().set_auto_parallel_search_mode,
|
||||
"parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
|
||||
"strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
|
||||
"strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
|
||||
|
@ -690,6 +740,7 @@ _set_auto_parallel_context_func_map = {
|
|||
"communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode,
|
||||
"optimizer_weight_shard_size": auto_parallel_context().set_optimizer_weight_shard_size,
|
||||
"optimizer_weight_shard_aggregated_save": auto_parallel_context().set_optimizer_weight_shard_aggregated_save,
|
||||
"sharding_propagation": auto_parallel_context().set_sharding_propagation,
|
||||
"enable_alltoall": auto_parallel_context().set_enable_alltoall}
|
||||
|
||||
|
||||
|
@ -702,6 +753,7 @@ _get_auto_parallel_context_func_map = {
|
|||
"pipeline_stages": auto_parallel_context().get_pipeline_stages,
|
||||
"parallel_mode": auto_parallel_context().get_parallel_mode,
|
||||
"search_mode": auto_parallel_context().get_strategy_search_mode,
|
||||
"auto_parallel_search_mode": auto_parallel_context().get_auto_parallel_search_mode,
|
||||
"parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
|
||||
"strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
|
||||
"strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file,
|
||||
|
@ -713,15 +765,16 @@ _get_auto_parallel_context_func_map = {
|
|||
"communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode,
|
||||
"optimizer_weight_shard_size": auto_parallel_context().get_optimizer_weight_shard_size,
|
||||
"optimizer_weight_shard_aggregated_save": auto_parallel_context().get_optimizer_weight_shard_aggregated_save,
|
||||
"sharding_propagation": auto_parallel_context().get_sharding_propagation,
|
||||
"enable_alltoall": auto_parallel_context().get_enable_alltoall}
|
||||
|
||||
|
||||
@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
|
||||
loss_repeated_mean=bool, parallel_mode=str, search_mode=str,
|
||||
loss_repeated_mean=bool, parallel_mode=str, search_mode=str, auto_parallel_search_mode=str,
|
||||
parameter_broadcast=bool, strategy_ckpt_load_file=str,
|
||||
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
|
||||
grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
|
||||
communi_parallel_mode=str, optimizer_weight_shard_size=int,
|
||||
communi_parallel_mode=str, optimizer_weight_shard_size=int, sharding_propagation=bool,
|
||||
optimizer_weight_shard_aggregated_save=bool, enable_alltoall=bool)
|
||||
|
||||
def _set_auto_parallel_context(**kwargs):
|
||||
|
@ -760,6 +813,8 @@ def _set_auto_parallel_context(**kwargs):
|
|||
- dynamic_programming: Dynamic programming search mode.
|
||||
|
||||
- sharding_propagation: Propagate shardings from configured ops to non-configured ops.
|
||||
auto_parallel_search_mode (str): This is the old version of 'search_mode'. Here, remaining this attribute is
|
||||
for forward compatibility, and this attribute will be deleted in a future MindSpore version.
|
||||
parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
|
||||
"stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
|
||||
broadcast. Default: False.
|
||||
|
@ -837,6 +892,8 @@ def _reset_auto_parallel_context():
|
|||
- strategy_ckpt_save_file: ""
|
||||
- enable_parallel_optimizer: False
|
||||
- search_mode: dynamic_programming
|
||||
- auto_parallel_search_mode: dynamic_programming
|
||||
- sharding_propagation: False
|
||||
- pipeline_stages: 0
|
||||
- gradient_accumulation_shard: True
|
||||
"""
|
||||
|
|
|
@ -77,7 +77,7 @@ def auto_parallel_activation3():
|
|||
|
||||
def test_auto_parallel_activation4():
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0,
|
||||
search_mode="sharding_propagation")
|
||||
auto_parallel_search_mode="sharding_propagation")
|
||||
strategy1 = ((4, 4), (4, 4))
|
||||
strategy2 = None
|
||||
strategy3 = ((8, 2),)
|
||||
|
|
Loading…
Reference in New Issue