In a previous PR (https://gitee.com/mindspore/mindspore/pulls/26807/), we replaced 'auto_parallel_search_mode' with 'search_mode' directly.
However, to remain compatible with existing scripts, 'auto_parallel_search_mode' should stay available. This PR restores the 'auto_parallel_search_mode' interface and emits a warning when the old interface is used. It also cleans up some related code style issues.
This commit is contained in:
parent 60bfef499f
commit 04db51a528
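As a quick usage sketch of what this PR restores (based on the docstring examples updated in the diff below), both spellings configure the same strategy search mode; the old one additionally logs a warning that it will be removed in a future MindSpore version:

    # Minimal sketch; mirrors the examples added to the set_auto_parallel_context docstring.
    from mindspore import context

    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    # New attribute name:
    context.set_auto_parallel_context(search_mode="dynamic_programming")
    # Old attribute name, kept for compatibility; logs a warning that
    # 'auto_parallel_search_mode' is replaced by 'search_mode'.
    context.set_auto_parallel_context(auto_parallel_search_mode="dynamic_programming")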
@@ -74,6 +74,7 @@ void ParallelContext::Reset() {
   optimizer_weight_shard_aggregated_save_ = false;
   enable_all2all_ = false;
   grad_accumulation_shard_ = true;
+  sharding_propagation_ = false;
   dataset_strategy_.clear();
 }
 
@@ -276,5 +277,7 @@ void ParallelContext::ParallelParameterContextCkptShape(const FuncGraphPtr &func
 }
 
 void ParallelContext::set_enable_all2all(const bool enable) { enable_all2all_ = enable; }
+
+void ParallelContext::set_sharding_propagation(const bool stra_pto) { sharding_propagation_ = stra_pto; }
 }  // namespace parallel
 }  // namespace mindspore

@@ -148,6 +148,8 @@ class ParallelContext {
                                             const AbstractBasePtr &ptr);
   void ParallelParameterContextCkptShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
                                          const AbstractBasePtr &ptr);
+  void set_sharding_propagation(const bool);
+  bool sharding_propagation() const { return sharding_propagation_; }
 
  private:
   ParallelContext();

@@ -183,6 +185,7 @@ class ParallelContext {
   std::vector<std::vector<int64_t>> dataset_strategy_;
   bool dataset_repeat_dim_right_ = false;
   bool hccl_test_available_ = false;
+  bool sharding_propagation_;
 };
 
 }  // namespace parallel
@@ -105,14 +105,15 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
   // search parallelization strategy
   if ((strategy_search_mode == DYNAMIC_PROGRAMMING) || (strategy_search_mode == SHARDING_PROPAGATION)) {
     if (ParallelStrategySearch(all_nodes, root) != SUCCESS) {
-      MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using DP searching mode";
+      MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using " << strategy_search_mode
+                        << " searching mode";
     }
   } else if (strategy_search_mode == RECURSIVE_PROGRAMMING) {
     if (ParallelStrategyRecSearch(all_nodes, root) != SUCCESS) {
       MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using RP searching mode";
     }
   } else {
-    MS_LOG(EXCEPTION) << "Auto-parallel strategy searching mode unexpected";
+    MS_LOG(EXCEPTION) << "Auto-parallel strategy searching mode unexpected: " << strategy_search_mode;
   }
 
   (void)gettimeofday(&end_time, nullptr);

@@ -288,6 +289,14 @@ void SetStrategyToOperator(const OperatorInfoPtr &operator_info, const Primitive
   (void)configured_stra_ops_.emplace(operator_info, strategyPtr);
 }
 
+void ApplyApproximationForNode(const OperatorInfoPtr &operator_info) {
+  auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
+  if (approximation) {
+    operator_info->ApproximateStrategies();
+    MS_LOG(INFO) << "Approximated StrategyCost for: " << operator_info->name();
+  }
+}
+
 OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, bool is_last_nodes,
                                       StrategyMap *stra_map) {
   MS_EXCEPTION_IF_NULL(prim);

@@ -369,8 +378,10 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
     return nullptr;
   }
 
-  if ((ParallelContext::GetInstance()->strategy_search_mode() == SHARDING_PROPAGATION) &&
-      (operator_info->name().find(VIRTUAL_DATA_SET_INFO) != std::string::npos)) {
+  bool use_sp_and_dataset = ((ParallelContext::GetInstance()->strategy_search_mode() == SHARDING_PROPAGATION) ||
+                             (ParallelContext::GetInstance()->sharding_propagation())) &&
+                            (operator_info->name().find(VIRTUAL_DATA_SET_INFO) != std::string::npos);
+  if (use_sp_and_dataset) {
     const auto &swc_vec = operator_info->GetStrategyCost();
     if (swc_vec.empty()) {
       MS_LOG(EXCEPTION) << "No available strategy for: " << operator_info->name();

@@ -379,11 +390,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
     (void)configured_stra_ops_.emplace(operator_info, swc_vec[0]->strategy_ptr);
   }
   // If 'approximation' is enabled, the 'strategy_cost' of each operator is approximated
-  auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
-  if (approximation) {
-    operator_info->ApproximateStrategies();
-    MS_LOG(INFO) << "Approximated StrategyCost for: " << operator_info->name();
-  }
+  ApplyApproximationForNode(operator_info);
   return operator_info;
 }
 

@@ -639,7 +646,9 @@ void CreateEdgeBetweenTwoOps(const OperatorInfoPtr &prev_op_info, const Operator
   node_op_info->AddPrevEdge(edge_ptr);
   prev_op_info->AddSuccEdge(edge_ptr);
   entire_costgraph->AddEdge(prev_op_info, node_op_info, edge_ptr);
-  if ((ParallelContext::GetInstance()->strategy_search_mode() == SHARDING_PROPAGATION) && (prev_prim->name() == CAST) &&
+  bool use_sp = (ParallelContext::GetInstance()->strategy_search_mode() == SHARDING_PROPAGATION) ||
+                (ParallelContext::GetInstance()->sharding_propagation());
+  if (use_sp && (prev_prim->name() == CAST) &&
       (configured_stra_ops_.find(node_op_info) != configured_stra_ops_.end())) {
     const auto next_op_stra = configured_stra_ops_[node_op_info];
     const auto cast_stra = edge_ptr->GetPrevOpStrategyByNextOpStrategyWithMiniComm(next_op_stra);

@@ -990,7 +999,9 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
   }
 
   // Step 4: run the strategy searching algorithm
-  if ((ParallelContext::GetInstance()->strategy_search_mode() == SHARDING_PROPAGATION)) {
+  bool use_sp = (ParallelContext::GetInstance()->strategy_search_mode() == SHARDING_PROPAGATION) ||
+                (ParallelContext::GetInstance()->sharding_propagation());
+  if (use_sp) {
     entire_costgraph->StrategyPropagate(configured_stra_ops_);
     configured_stra_ops_.clear();
   } else if (GetStrategy(entire_costgraph) != SUCCESS) {
@@ -211,6 +211,9 @@ PYBIND11_MODULE(_c_expression, m) {
          "Get whether to integrated save weight shard when enable parallel optimizer.")
     .def("set_enable_alltoall", &ParallelContext::set_enable_all2all, "Set the enabling AllToAll value.")
     .def("get_enable_alltoall", &ParallelContext::enable_all2all, "Get the enabling AllToAll value.")
+    .def("set_sharding_propagation", &ParallelContext::set_sharding_propagation,
+         "Set sharding strategy propagation value.")
+    .def("get_sharding_propagation", &ParallelContext::sharding_propagation, "Get sharding strategy propagation value.")
     .def("reset", &ParallelContext::Reset, "Reset auto parallel context.");
 
   (void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")
@@ -370,7 +370,7 @@ def _context():
 
 
 @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, parallel_mode=str,
-                 search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
+                 auto_parallel_search_mode=str, search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
                  strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
                  all_reduce_fusion_config=list, pipeline_stages=int, grad_accumulation_step=int,
                  parallel_optimizer_config=dict)

@@ -401,6 +401,7 @@ def set_auto_parallel_context(**kwargs):
              enable_parallel_optimizer    dataset_strategy
              parallel_optimizer_config    pipeline_stages
                        \                  grad_accumulation_step
+                       \                  auto_parallel_search_mode
              ===========================  ===========================
 
     Args:

@@ -431,6 +432,8 @@ def set_auto_parallel_context(**kwargs):
                      - dynamic_programming: Dynamic programming search mode.
 
                      - sharding_propagation: Propagate shardings from configured ops to non-configured ops.
+        auto_parallel_search_mode (str): This is the old version of 'search_mode'. Here, remaining this attribute is
+                       for forward compatibility, and this attribute will be deleted in a future MindSpore version.
         parameter_broadcast (bool): Whether to broadcast parameters before training. Before training, in order to have
                        the same network initialization parameter values for all devices, broadcast the parameters
                        on device 0 to other devices. Parameter broadcasting in different parallel modes is different,

@@ -485,6 +488,7 @@ def set_auto_parallel_context(**kwargs):
         >>> context.set_auto_parallel_context(gradient_fp32_sync=False)
         >>> context.set_auto_parallel_context(parallel_mode="auto_parallel")
         >>> context.set_auto_parallel_context(search_mode="dynamic_programming")
+        >>> context.set_auto_parallel_context(auto_parallel_search_mode="dynamic_programming")
         >>> context.set_auto_parallel_context(parameter_broadcast=False)
         >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
         >>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")

@@ -529,6 +533,7 @@ def reset_auto_parallel_context():
         - gradient_fp32_sync: True.
         - parallel_mode: 'stand_alone'.
         - search_mode: 'dynamic_programming'.
+        - auto_parallel_search_mode: 'dynamic_programming'.
         - parameter_broadcast: False.
         - strategy_ckpt_load_file: ''.
         - strategy_ckpt_save_file: ''.
@@ -233,6 +233,55 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_strategy_search_mode()
 
+    def set_auto_parallel_search_mode(self, search_mode):
+        """
+        Set search mode of strategy searching. This is the old version of 'search_mode', and will be deleted in a future
+        MindSpore version.
+
+        Args:
+            search_mode (str): The search mode of strategy.
+        """
+        logger.warning("The attribute 'auto_parallel_search_mode' is currently replaced by 'search_mode'. "
+                       "The attribute 'auto_parallel_search_mode' will be deleted in a future MindSpore version.")
+        self.check_context_handle()
+        ret = self._context_handle.set_strategy_search_mode(search_mode)
+        if ret is False:
+            raise ValueError("The context configuration parameter 'search_mode' only support "
+                             "'recursive_programming' and 'dynamic_programming', but got the value : {}."
+                             .format(search_mode))
+
+    def get_auto_parallel_search_mode(self):
+        """Get search mode of strategy. This is the old version of 'search_mode', and will be deleted in a future
+        MindSpore version.
+        """
+        logger.warning("The attribute 'auto_parallel_search_mode' is currently replaced by 'search_mode'. "
+                       "The attribute 'auto_parallel_search_mode' will be deleted in a future MindSpore version.")
+        self.check_context_handle()
+        return self._context_handle.get_strategy_search_mode()
+
+    def set_sharding_propagation(self, sharding_propagation):
+        """
+        Set the value of sharding strategy propagation in AUTO_PARALLEL mode. If True, the strategy-configured operators
+        will propagate the strategies to other operators with minimum redistribution cost; otherwise, the algorithm
+        will search the desired strategies. Default: False.
+        This attribute is replaced by context.set_auto_parallel(search_mode="sharding_propagation").
+
+        Args:
+            sharding_propagation (bool): Enable/disable strategy propagation.
+        """
+        logger.warning("This attribute is replaced by context.set_auto_parallel(search_mode='sharding_propagation'), "
+                       "and this attribute will be deleted in a future MindSpore version.")
+        self.check_context_handle()
+        if not isinstance(sharding_propagation, bool):
+            raise TypeError("The type of parameter 'sharding_propagation' must be bool, "
+                            "but got the type : {}.".format(type(sharding_propagation)))
+        self._context_handle.set_sharding_propagation(sharding_propagation)
+
+    def get_sharding_propagation(self):
+        """Get the value of sharding strategy propagation."""
+        self.check_context_handle()
+        return self._context_handle.get_sharding_propagation()
+
     def set_parameter_broadcast(self, parameter_broadcast):
         """
         Set parameter broadcast.
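Usage sketch for the restored flag (a hedged example, assuming the public context wrappers forward these keys to the setter/getter maps updated below): sharding propagation can be enabled either through the new search mode or through the legacy 'sharding_propagation' flag, which the docstring above describes as replaced by search_mode='sharding_propagation' and which now emits a warning:

    # Minimal sketch; both calls configure the same underlying auto-parallel context handle.
    from mindspore import context

    # Preferred, new-style configuration:
    context.set_auto_parallel_context(parallel_mode="auto_parallel",
                                      search_mode="sharding_propagation")

    # Legacy flag kept for compatibility; warns and points to
    # search_mode="sharding_propagation".
    context.set_auto_parallel_context(sharding_propagation=True)
    print(context.get_auto_parallel_context("sharding_propagation"))  # True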
@@ -677,6 +726,7 @@ _set_auto_parallel_context_func_map = {
     "pipeline_stages": auto_parallel_context().set_pipeline_stages,
     "parallel_mode": auto_parallel_context().set_parallel_mode,
     "search_mode": auto_parallel_context().set_strategy_search_mode,
+    "auto_parallel_search_mode": auto_parallel_context().set_auto_parallel_search_mode,
     "parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
     "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
     "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,

@@ -690,6 +740,7 @@ _set_auto_parallel_context_func_map = {
     "communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode,
     "optimizer_weight_shard_size": auto_parallel_context().set_optimizer_weight_shard_size,
     "optimizer_weight_shard_aggregated_save": auto_parallel_context().set_optimizer_weight_shard_aggregated_save,
+    "sharding_propagation": auto_parallel_context().set_sharding_propagation,
     "enable_alltoall": auto_parallel_context().set_enable_alltoall}
 
 

@@ -702,6 +753,7 @@ _get_auto_parallel_context_func_map = {
     "pipeline_stages": auto_parallel_context().get_pipeline_stages,
     "parallel_mode": auto_parallel_context().get_parallel_mode,
     "search_mode": auto_parallel_context().get_strategy_search_mode,
+    "auto_parallel_search_mode": auto_parallel_context().get_auto_parallel_search_mode,
     "parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
     "strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
     "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file,

@@ -713,15 +765,16 @@ _get_auto_parallel_context_func_map = {
     "communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode,
     "optimizer_weight_shard_size": auto_parallel_context().get_optimizer_weight_shard_size,
     "optimizer_weight_shard_aggregated_save": auto_parallel_context().get_optimizer_weight_shard_aggregated_save,
+    "sharding_propagation": auto_parallel_context().get_sharding_propagation,
     "enable_alltoall": auto_parallel_context().get_enable_alltoall}
 
 
 @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
-                 loss_repeated_mean=bool, parallel_mode=str, search_mode=str,
+                 loss_repeated_mean=bool, parallel_mode=str, search_mode=str, auto_parallel_search_mode=str,
                  parameter_broadcast=bool, strategy_ckpt_load_file=str,
                  strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
                  grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
-                 communi_parallel_mode=str, optimizer_weight_shard_size=int,
+                 communi_parallel_mode=str, optimizer_weight_shard_size=int, sharding_propagation=bool,
                  optimizer_weight_shard_aggregated_save=bool, enable_alltoall=bool)
 
 def _set_auto_parallel_context(**kwargs):

@@ -760,6 +813,8 @@ def _set_auto_parallel_context(**kwargs):
                      - dynamic_programming: Dynamic programming search mode.
 
                      - sharding_propagation: Propagate shardings from configured ops to non-configured ops.
+        auto_parallel_search_mode (str): This is the old version of 'search_mode'. Here, remaining this attribute is
+                       for forward compatibility, and this attribute will be deleted in a future MindSpore version.
         parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
                        "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
                        broadcast. Default: False.

@@ -837,6 +892,8 @@ def _reset_auto_parallel_context():
         - strategy_ckpt_save_file: ""
         - enable_parallel_optimizer: False
        - search_mode: dynamic_programming
+        - auto_parallel_search_mode: dynamic_programming
+        - sharding_propagation: False
         - pipeline_stages: 0
         - gradient_accumulation_shard: True
     """
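For completeness, a small hypothetical check (assuming the standard context.get_auto_parallel_context accessor) that the restored key and the new key report the same value, since both map entries above call into the same strategy-search-mode getter; reading the old key also logs the deprecation warning:

    from mindspore import context

    context.set_auto_parallel_context(search_mode="recursive_programming")
    new_value = context.get_auto_parallel_context("search_mode")
    old_value = context.get_auto_parallel_context("auto_parallel_search_mode")  # warns
    assert new_value == old_value == "recursive_programming"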
@@ -77,7 +77,7 @@ def auto_parallel_activation3():
 
 def test_auto_parallel_activation4():
     context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0,
-                                      search_mode="sharding_propagation")
+                                      auto_parallel_search_mode="sharding_propagation")
     strategy1 = ((4, 4), (4, 4))
     strategy2 = None
     strategy3 = ((8, 2),)