!49429 strategy_ckpt_interface_change.

Merge pull request !49429 from yao_yf/strategy_ckpt_interface_change
i-robot 2023-03-07 02:15:12 +00:00 committed by Gitee
commit d5017889cb
7 changed files with 96 additions and 10 deletions


@@ -1287,6 +1287,7 @@ ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_
<< MAX_RECURSIVE_DEPTH;
return {};
}
bool only_trainable_params = ParallelContext::GetInstance()->stra_file_only_trainable_params();
std::vector<AnfNodePtr> node_inputs{node->inputs()};
ParameterMap param_names;
for (int64_t i = 0; i < UlongToLong(node_inputs.size()); ++i) {
@@ -1294,7 +1295,7 @@ ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_
auto input = node_inputs[LongToSize(i)];
if (input->isa<Parameter>()) {
auto input_parameter = input->cast<ParameterPtr>();
if (input_parameter->has_default()) {
if (input_parameter->has_default() && (!only_trainable_params || ParameterRequireGrad(input_parameter))) {
(void)param_names.emplace_back(std::make_pair(input_parameter->name(), input_parameter));
}
} else if (input->isa<CNode>()) {
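Taken on its own, the C++ change above only narrows which parameters are recorded in the strategy file: a parameter is kept when it has a default value and, if only_trainable_params is set, also requires gradients. A rough Python paraphrase of that filter (the Param fields below stand in for the C++ helpers has_default() and ParameterRequireGrad; a sketch for illustration, not MindSpore internals):

from collections import namedtuple

# Stand-in for the graph parameters that NodeParameterName inspects.
Param = namedtuple("Param", ["name", "has_default", "requires_grad"])

def collect_strategy_param_names(params, only_trainable_params=True):
    # Keep a parameter only if it has a default value and, when the
    # only_trainable_params switch is on, it also requires gradients.
    return [p.name for p in params
            if p.has_default and (not only_trainable_params or p.requires_grad)]

params = [Param("weight1", True, False), Param("weight2", True, True)]
print(collect_strategy_param_names(params))         # ['weight2']
print(collect_strategy_param_names(params, False))  # ['weight1', 'weight2']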


@@ -192,6 +192,9 @@ class COMMON_EXPORT ParallelContext {
void set_do_transform(const bool);
bool do_transform() const { return do_transform_; }
void set_stra_file_only_trainable_params(const bool);
bool stra_file_only_trainable_params() const { return stra_file_only_trainable_params_; }
private:
ParallelContext();
bool ParallelContextCareGraph(const FuncGraphPtr &func_graph) const;
@@ -239,6 +242,7 @@ class COMMON_EXPORT ParallelContext {
bool sharding_propagation_;
bool enable_micro_interleaved_ = false;
bool do_transform_ = false;
bool stra_file_only_trainable_params_ = true;
std::string fusion_mode_;
};
} // namespace mindspore::parallel


@@ -305,6 +305,10 @@ PYBIND11_MODULE(_c_expression, m) {
.def("get_full_batch_is_set", &ParallelContext::full_batch_is_set, "Get whether attr full_batch is set.")
.def("set_dataset_strategy", &ParallelContext::set_dataset_strategy, "Set dataset sharding strategy.")
.def("get_dataset_strategy", &ParallelContext::dataset_strategy, "Get dataset sharding strategy.")
.def("set_stra_file_only_trainable_params", &ParallelContext::set_stra_file_only_trainable_params,
"Set strategy ckpt only save trainable params.")
.def("get_stra_file_only_trainable_params", &ParallelContext::stra_file_only_trainable_params,
"Get strategy ckpt only save trainable params.")
.def("set_enable_parallel_optimizer", &ParallelContext::set_enable_parallel_optimizer,
"Set enable/disable parallel optimizer.")
.def("get_enable_parallel_optimizer", &ParallelContext::enable_parallel_optimizer,
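These new bindings are reached from Python through the auto-parallel context handle rather than called directly; a minimal round-trip sketch, assuming the private _context_handle attribute that _auto_parallel_context.py uses further below (not a public API):

from mindspore.parallel._auto_parallel_context import auto_parallel_context

handle = auto_parallel_context()._context_handle  # _c_expression ParallelContext wrapper
print(handle.get_stra_file_only_trainable_params())  # True by default, per stra_file_only_trainable_params_ above
handle.set_stra_file_only_trainable_params(False)
print(handle.get_stra_file_only_trainable_params())  # now False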


@@ -283,5 +283,9 @@ void ParallelContext::set_pipeline_micro_size(const size_t pipeline_micro_size)
void ParallelContext::set_do_transform(const bool do_transform) { do_transform_ = do_transform; }
void ParallelContext::set_stra_file_only_trainable_params(const bool stra_file_only_trainable_params) {
stra_file_only_trainable_params_ = stra_file_only_trainable_params;
}
void ParallelContext::set_sharding_propagation(const bool stra_pto) { sharding_propagation_ = stra_pto; }
} // namespace mindspore::parallel


@@ -521,7 +521,7 @@ def _context():
auto_parallel_search_mode=str, search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool, enable_alltoall=bool,
all_reduce_fusion_config=list, pipeline_stages=int, grad_accumulation_step=int,
parallel_optimizer_config=dict, comm_fusion=dict)
parallel_optimizer_config=dict, comm_fusion=dict, strategy_ckpt_config=dict)
def set_auto_parallel_context(**kwargs):
r"""
Set auto parallel context, only data parallel supported on CPU.
@@ -549,6 +549,7 @@ def set_auto_parallel_context(**kwargs):
enable_alltoall grad_accumulation_step
\ auto_parallel_search_mode
\ comm_fusion
\ strategy_ckpt_config
=========================== ===========================
Args:
@@ -587,8 +588,10 @@ def set_auto_parallel_context(**kwargs):
data_parallel mode, all parameters are broadcast except for the parameter whose attribute
layerwise_parallel is True. Hybrid_parallel, semi_auto_parallel and auto_parallel mode, the
segmented parameters do not participate in broadcasting. Default: False.
strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
strategy_ckpt_load_file (str): The path to load the parallel strategy checkpoint. This interface is not
recommended currently; it is better to use 'strategy_ckpt_config' instead. Default: ''
strategy_ckpt_save_file (str): The path to save the parallel strategy checkpoint. This interface is not
recommended currently; it is better to use 'strategy_ckpt_config' instead. Default: ''
full_batch (bool): If you load whole batch datasets in auto_parallel mode, this parameter
should be set as True. Default: False. The interface is not to be recommended currently,
it is better using 'dataset_strategy' to replace it.
@@ -654,6 +657,24 @@ def set_auto_parallel_context(**kwargs):
- reducescatter: If communication fusion type is `reducescatter`. The `mode` contains: `auto`
and `size`. Config is same as `allgather`.
strategy_ckpt_config (dict): A dict that contains the configurations for setting the parallel strategy file. This
interface covers the functions of the interfaces `strategy_ckpt_load_file` and
`strategy_ckpt_save_file`; it is recommended to use this interface to replace those two
interfaces.
It contains the following configurations:
- load_file (str): The path to load the parallel strategy checkpoint. If the file name extension is
`.json`, the file is loaded in JSON format. Otherwise, the file is loaded in protobuf
format.
Default: ''
- save_file (str): The path to save the parallel strategy checkpoint. If the file name extension is
`.json`, the file is saved in JSON format. Otherwise, the file is saved in protobuf format.
Default: ''
- only_trainable_params (bool): Whether to save/load the strategy information only for trainable
parameters. Default: True.
Raises:
ValueError: If input key is not attribute in auto parallel context.
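Read together with the tests further below, the docstring boils down to the following usage; a minimal before/after sketch of the public call (file paths are placeholders):

from mindspore import context

# Old style, still accepted but no longer recommended:
context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy.ckpt")

# New style introduced by this change: one dict for the load/save/filter settings.
context.set_auto_parallel_context(strategy_ckpt_config={
    "save_file": "./strategy.json",   # .json extension -> JSON format, otherwise protobuf
    "only_trainable_params": True,    # default: record only trainable parameters
})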


@@ -509,6 +509,52 @@ class _AutoParallelContext:
self.check_context_handle()
return self._context_handle.get_strategy_ckpt_save_file()
def set_strategy_ckpt_config(self, strategy_ckpt_config):
"""
Set strategy checkpoint config.
Args:
strategy_ckpt_config (dict): The strategy checkpoint config.
"""
self.check_context_handle()
if not isinstance(strategy_ckpt_config, dict):
raise TypeError("For 'set_auto_parallel_context', the argument 'strategy_ckpt_config' "
"must be dict, but got the type : {}.".format(type(strategy_ckpt_config)))
for config_name in strategy_ckpt_config:
unknown_config = []
if config_name not in ["load_file", "save_file", "only_trainable_params"]:
unknown_config.append(config_name)
if unknown_config:
raise ValueError("Unknown config: {}".format(unknown_config))
if "load_file" in strategy_ckpt_config:
load_file = strategy_ckpt_config.get("load_file")
if not isinstance(load_file, str):
raise TypeError("For 'set_auto_parallel_context().set_strategy_ckpt_config', "
"the argument 'load_file' must be str, but got the type : {} .".format(type(load_file)))
self._context_handle.set_strategy_ckpt_load_file(load_file)
if "save_file" in strategy_ckpt_config:
save_file = strategy_ckpt_config.get("save_file")
if not isinstance(save_file, str):
raise TypeError("For 'set_auto_parallel_context().set_strategy_ckpt_config', "
"the argument 'save_file' must be str, but got the type : {} .".format(type(save_file)))
self._context_handle.set_strategy_ckpt_save_file(save_file)
if "only_trainable_params" in strategy_ckpt_config:
only_trainable_params = strategy_ckpt_config.get("only_trainable_params")
if not isinstance(only_trainable_params, bool):
raise TypeError("For 'set_auto_parallel_context().set_strategy_ckpt_config', "
"the argument 'only_trainable_params' must be bool,"
" but got the type : {} .".format(type(only_trainable_params)))
self._context_handle.set_stra_file_only_trainable_params(only_trainable_params)
def get_strategy_ckpt_config(self):
"""Get strategy checkpoint config."""
self.check_context_handle()
load_file = self._context_handle.get_strategy_ckpt_load_file()
save_file = self._context_handle.get_strategy_ckpt_save_file()
only_trainable_param = self._context_handle.get_stra_file_only_trainable_params()
return {"load_file": load_file, "save_file": save_file, "only_trainable_params": only_trainable_param}
def set_group_ckpt_save_file(self, group_ckpt_save_file):
"""Set group checkpoint save path."""
self.check_context_handle()
@@ -1015,6 +1061,7 @@ _set_auto_parallel_context_func_map = {
"optimizer_weight_shard_aggregated_save": auto_parallel_context().set_optimizer_weight_shard_aggregated_save,
"sharding_propagation": auto_parallel_context().set_sharding_propagation,
"enable_alltoall": auto_parallel_context().set_enable_alltoall,
"strategy_ckpt_config": auto_parallel_context().set_strategy_ckpt_config,
"comm_fusion": auto_parallel_context().set_comm_fusion}
@@ -1042,6 +1089,7 @@ _get_auto_parallel_context_func_map = {
"sharding_propagation": auto_parallel_context().get_sharding_propagation,
"enable_alltoall": auto_parallel_context().get_enable_alltoall,
"comm_fusion": auto_parallel_context().get_comm_fusion,
"strategy_ckpt_config": auto_parallel_context().get_strategy_ckpt_config,
"full_batch_is_set": auto_parallel_context().get_full_batch_is_set}
@@ -1051,7 +1099,8 @@ _get_auto_parallel_context_func_map = {
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
communi_parallel_mode=str, optimizer_weight_shard_size=int, sharding_propagation=bool,
optimizer_weight_shard_aggregated_save=bool, enable_alltoall=bool, comm_fusion=dict)
optimizer_weight_shard_aggregated_save=bool, enable_alltoall=bool, comm_fusion=dict,
strategy_ckpt_config=dict)
def _set_auto_parallel_context(**kwargs):
"""

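A quick sketch of the new setter/getter pair on the Python context singleton, matching the validation logic above (an unknown key raises ValueError, a wrongly typed value raises TypeError; the import path follows the upstream MindSpore layout):

from mindspore.parallel._auto_parallel_context import auto_parallel_context

ctx = auto_parallel_context()
ctx.set_strategy_ckpt_config({"save_file": "./strategy.ckpt",
                              "only_trainable_params": False})
print(ctx.get_strategy_ckpt_config())
# {'load_file': '', 'save_file': './strategy.ckpt', 'only_trainable_params': False}

try:
    ctx.set_strategy_ckpt_config({"save_path": "./strategy.ckpt"})  # unsupported key
except ValueError as err:
    print(err)  # Unknown config: ['save_path']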

@@ -75,7 +75,7 @@ def test_six_matmul_save():
return out
reset_auto_parallel_context()
set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1.ckpt",
set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_config={"save_file": "./strategy_stage1.ckpt"},
group_ckpt_save_file="./group_stage1.ckpt", dataset_strategy="full_batch")
strategy1 = ((8, 1), (1, 1))
strategy2 = ((1, 8), (8, 1))
@@ -137,7 +137,7 @@ def six_matmul_load():
return out
reset_auto_parallel_context()
set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1.ckpt",
set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_config={"load_file": "./strategy_stage1.ckpt"},
group_ckpt_save_file="./group_stage1.ckpt", dataset_strategy="full_batch")
strategy1 = ((8, 1), (1, 1))
strategy3 = ((8, 1), (1, 1))
@@ -183,7 +183,7 @@ def test_six_matmul_save_auto():
self.matmul4 = P.MatMul()
self.matmul5 = P.MatMul()
self.matmul6 = P.MatMul()
self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1", requires_grad=False)
self.weight2 = Parameter(Tensor(np.ones([64, 64]), dtype=ms.float32), name="weight2")
self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
@@ -201,7 +201,9 @@ def test_six_matmul_save_auto():
return out
reset_auto_parallel_context()
set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1_auto.json")
set_auto_parallel_context(device_num=8, global_rank=0,
strategy_ckpt_config={"save_file": "./strategy_stage1_auto.json",
"only_trainable_params": True})
net = GradWrap(NetWithLoss(Net()))
context.set_auto_parallel_context(parallel_mode="auto_parallel", dataset_strategy="full_batch")
x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
@@ -256,7 +258,8 @@ def six_matmul_load_auto():
return out
reset_auto_parallel_context()
set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1_auto.json")
set_auto_parallel_context(device_num=8, global_rank=0,
strategy_ckpt_config={"load_file": "./strategy_stage1_auto.json"})
strategy1 = ((2, 2), (2, 2))
strategy3 = ((2, 2), (2, 2))
strategy4 = ((2, 2), (2, 2))