!49429 strategy_ckpt_interface_change.
Merge pull request !49429 from yao_yf/strategy_ckpt_interface_change
Commit d5017889cb
@@ -1287,6 +1287,7 @@ ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_
                       << MAX_RECURSIVE_DEPTH;
     return {};
   }
+  bool only_trainable_params = ParallelContext::GetInstance()->stra_file_only_trainable_params();
   std::vector<AnfNodePtr> node_inputs{node->inputs()};
   ParameterMap param_names;
   for (int64_t i = 0; i < UlongToLong(node_inputs.size()); ++i) {
@@ -1294,7 +1295,7 @@ ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_
     auto input = node_inputs[LongToSize(i)];
     if (input->isa<Parameter>()) {
       auto input_parameter = input->cast<ParameterPtr>();
-      if (input_parameter->has_default()) {
+      if (input_parameter->has_default() && (!only_trainable_params || ParameterRequireGrad(input_parameter))) {
        (void)param_names.emplace_back(std::make_pair(input_parameter->name(), input_parameter));
      }
    } else if (input->isa<CNode>()) {
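The predicate added above is small but easy to misread, so here is a minimal Python sketch of it (the Param dataclass and its fields are illustrative stand-ins, not MindSpore API): a parameter is recorded in the strategy file only if it has a default value and, when the new flag is on, only if it also requires a gradient.

from dataclasses import dataclass

@dataclass
class Param:                      # illustrative stand-in for ParameterPtr
    name: str
    has_default: bool
    requires_grad: bool

def keep_in_strategy_file(p: Param, only_trainable_params: bool) -> bool:
    # Mirrors: has_default() && (!only_trainable_params || ParameterRequireGrad(param))
    return p.has_default and (not only_trainable_params or p.requires_grad)

frozen = Param("weight1", has_default=True, requires_grad=False)
assert keep_in_strategy_file(frozen, only_trainable_params=False)      # previous behaviour
assert not keep_in_strategy_file(frozen, only_trainable_params=True)   # new default: skip frozen params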
@@ -192,6 +192,9 @@ class COMMON_EXPORT ParallelContext {
   void set_do_transform(const bool);
   bool do_transform() const { return do_transform_; }

+  void set_stra_file_only_trainable_params(const bool);
+  bool stra_file_only_trainable_params() const { return stra_file_only_trainable_params_; }
+
  private:
   ParallelContext();
   bool ParallelContextCareGraph(const FuncGraphPtr &func_graph) const;
@@ -239,6 +242,7 @@ class COMMON_EXPORT ParallelContext {
   bool sharding_propagation_;
   bool enable_micro_interleaved_ = false;
   bool do_transform_ = false;
+  bool stra_file_only_trainable_params_ = true;
   std::string fusion_mode_;
 };
 }  // namespace mindspore::parallel
@@ -305,6 +305,10 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("get_full_batch_is_set", &ParallelContext::full_batch_is_set, "Get whether attr full_batch is set.")
     .def("set_dataset_strategy", &ParallelContext::set_dataset_strategy, "Set dataset sharding strategy.")
     .def("get_dataset_strategy", &ParallelContext::dataset_strategy, "Get dataset sharding strategy.")
+    .def("set_stra_file_only_trainable_params", &ParallelContext::set_stra_file_only_trainable_params,
+         "Set strategy ckpt only save trainable params.")
+    .def("get_stra_file_only_trainable_params", &ParallelContext::stra_file_only_trainable_params,
+         "Get strategy ckpt only save trainable params.")
     .def("set_enable_parallel_optimizer", &ParallelContext::set_enable_parallel_optimizer,
          "Set enable/disable parallel optimizer.")
     .def("get_enable_parallel_optimizer", &ParallelContext::enable_parallel_optimizer,
@@ -283,5 +283,9 @@ void ParallelContext::set_pipeline_micro_size(const size_t pipeline_micro_size)

 void ParallelContext::set_do_transform(const bool do_transform) { do_transform_ = do_transform; }

+void ParallelContext::set_stra_file_only_trainable_params(const bool stra_file_only_trainable_params) {
+  stra_file_only_trainable_params_ = stra_file_only_trainable_params;
+}
+
 void ParallelContext::set_sharding_propagation(const bool stra_pto) { sharding_propagation_ = stra_pto; }
 }  // namespace mindspore::parallel
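For readers who do not want to trace the C++, a hedged Python sketch of what these two hunks add to the ParallelContext singleton: one boolean flag, defaulting to True, with a plain setter and getter (class and method names mirror the diff; this sketch is not the real binding).

class ParallelContextSketch:
    """Illustrative stand-in for the C++ ParallelContext singleton."""
    _instance = None

    def __init__(self):
        self._stra_file_only_trainable_params = True  # default taken from the header diff

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def set_stra_file_only_trainable_params(self, value: bool) -> None:
        self._stra_file_only_trainable_params = value

    def stra_file_only_trainable_params(self) -> bool:
        return self._stra_file_only_trainable_params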
@@ -521,7 +521,7 @@ def _context():
                  auto_parallel_search_mode=str, search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
                  strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool, enable_alltoall=bool,
                  all_reduce_fusion_config=list, pipeline_stages=int, grad_accumulation_step=int,
-                 parallel_optimizer_config=dict, comm_fusion=dict)
+                 parallel_optimizer_config=dict, comm_fusion=dict, strategy_ckpt_config=dict)
 def set_auto_parallel_context(**kwargs):
     r"""
     Set auto parallel context, only data parallel supported on CPU.
@@ -549,6 +549,7 @@ def set_auto_parallel_context(**kwargs):
              enable_alltoall              grad_accumulation_step
              \                            auto_parallel_search_mode
              \                            comm_fusion
+             \                            strategy_ckpt_config
              ===========================  ===========================

     Args:
@@ -587,8 +588,10 @@ def set_auto_parallel_context(**kwargs):
                      data_parallel mode, all parameters are broadcast except for the parameter whose attribute
                      layerwise_parallel is True. Hybrid_parallel, semi_auto_parallel and auto_parallel mode, the
                      segmented parameters do not participate in broadcasting. Default: False.
-        strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
-        strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
+        strategy_ckpt_load_file (str): The path to load the parallel strategy checkpoint. This interface is not
+                     recommended; it is better to use 'strategy_ckpt_config' instead. Default: ''
+        strategy_ckpt_save_file (str): The path to save the parallel strategy checkpoint. This interface is not
+                     recommended; it is better to use 'strategy_ckpt_config' instead. Default: ''
         full_batch (bool): If you load whole batch datasets in auto_parallel mode, this parameter
                      should be set as True. Default: False. The interface is not to be recommended currently,
                      it is better using 'dataset_strategy' to replace it.
@@ -654,6 +657,24 @@ def set_auto_parallel_context(**kwargs):
                      - reducescatter: If communication fusion type is `reducescatter`. The `mode` contains: `auto`
                        and `size`. Config is same as `allgather`.

+        strategy_ckpt_config (dict): A dict containing the configurations for setting the parallel strategy file.
+                     This interface covers the functions of the interfaces `strategy_ckpt_load_file` and
+                     `strategy_ckpt_save_file`; it is recommended to use it instead of those two
+                     interfaces.
+                     It contains the following configurations:
+
+                     - load_file (str): The path to load the parallel strategy checkpoint. If the file name extension
+                       is `.json`, the file is loaded in JSON format. Otherwise, the file is loaded in protobuf
+                       format.
+                       Default: ''
+
+                     - save_file (str): The path to save the parallel strategy checkpoint. If the file name extension
+                       is `.json`, the file is saved in JSON format. Otherwise, the file is saved in protobuf format.
+                       Default: ''
+
+                     - only_trainable_params (bool): Only save/load the strategy information for trainable parameters.
+                       Default: True.
+
     Raises:
         ValueError: If input key is not attribute in auto parallel context.
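Taken together, the documentation above amounts to the following call pattern. A short, hedged usage sketch (the file paths are placeholders; the dict keys are exactly the three documented above):

from mindspore import context

# Save the parallel strategy of trainable parameters only (the default) to a protobuf file,
# or to JSON if the extension is .json.
context.set_auto_parallel_context(strategy_ckpt_config={"save_file": "./train_strategy.ckpt",
                                                        "only_trainable_params": True})

# Later (e.g. for inference or resharding), load the same strategy file back.
context.set_auto_parallel_context(strategy_ckpt_config={"load_file": "./train_strategy.ckpt"})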
@@ -509,6 +509,52 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_strategy_ckpt_save_file()

+    def set_strategy_ckpt_config(self, strategy_ckpt_config):
+        """
+        Set strategy checkpoint config.
+
+        Args:
+            strategy_ckpt_config (dict): The strategy checkpoint config.
+        """
+        self.check_context_handle()
+        if not isinstance(strategy_ckpt_config, dict):
+            raise TypeError("For 'set_auto_parallel_context', the argument 'strategy_ckpt_config' "
+                            "must be dict, but got the type : {}.".format(type(strategy_ckpt_config)))
+        for config_name in strategy_ckpt_config:
+            unknown_config = []
+            if config_name not in ["load_file", "save_file", "only_trainable_params"]:
+                unknown_config.append(config_name)
+
+            if unknown_config:
+                raise ValueError("Unknown config: {}".format(unknown_config))
+        if "load_file" in strategy_ckpt_config:
+            load_file = strategy_ckpt_config.get("load_file")
+            if not isinstance(load_file, str):
+                raise TypeError("For 'set_auto_parallel_context().set_strategy_ckpt_config', "
+                                "the argument 'load_file' must be str, but got the type : {} .".format(type(load_file)))
+            self._context_handle.set_strategy_ckpt_load_file(load_file)
+        if "save_file" in strategy_ckpt_config:
+            save_file = strategy_ckpt_config.get("save_file")
+            if not isinstance(save_file, str):
+                raise TypeError("For 'set_auto_parallel_context().set_strategy_ckpt_config', "
+                                "the argument 'save_file' must be str, but got the type : {} .".format(type(save_file)))
+            self._context_handle.set_strategy_ckpt_save_file(save_file)
+        if "only_trainable_params" in strategy_ckpt_config:
+            only_trainable_params = strategy_ckpt_config.get("only_trainable_params")
+            if not isinstance(only_trainable_params, bool):
+                raise TypeError("For 'set_auto_parallel_context().set_strategy_ckpt_config', "
+                                "the argument 'only_trainable_params' must be bool,"
+                                " but got the type : {} .".format(type(only_trainable_params)))
+            self._context_handle.set_stra_file_only_trainable_params(only_trainable_params)
+
+    def get_strategy_ckpt_config(self):
+        """Get strategy checkpoint config."""
+        self.check_context_handle()
+        load_file = self._context_handle.get_strategy_ckpt_load_file()
+        save_file = self._context_handle.get_strategy_ckpt_save_file()
+        only_trainable_param = self._context_handle.get_stra_file_only_trainable_params()
+        return {"load_file": load_file, "save_file": save_file, "only_trainable_params": only_trainable_param}
+
     def set_group_ckpt_save_file(self, group_ckpt_save_file):
         """Set group checkpoint save path."""
         self.check_context_handle()
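A hedged sketch of how the new methods behave at the Python layer, based only on the code above (auto_parallel_context() is the module-level accessor this file already uses; running it requires an installed MindSpore with the new bindings):

from mindspore.parallel._auto_parallel_context import auto_parallel_context

ctx = auto_parallel_context()
ctx.set_strategy_ckpt_config({"save_file": "./strategy.ckpt", "only_trainable_params": False})
print(ctx.get_strategy_ckpt_config())
# expected: {'load_file': '', 'save_file': './strategy.ckpt', 'only_trainable_params': False}

# Unknown keys and wrong value types are rejected up front:
# ctx.set_strategy_ckpt_config({"save_path": "./strategy.ckpt"})   # ValueError: Unknown config
# ctx.set_strategy_ckpt_config({"only_trainable_params": "yes"})   # TypeError: must be bool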
@@ -1015,6 +1061,7 @@ _set_auto_parallel_context_func_map = {
     "optimizer_weight_shard_aggregated_save": auto_parallel_context().set_optimizer_weight_shard_aggregated_save,
     "sharding_propagation": auto_parallel_context().set_sharding_propagation,
     "enable_alltoall": auto_parallel_context().set_enable_alltoall,
+    "strategy_ckpt_config": auto_parallel_context().set_strategy_ckpt_config,
     "comm_fusion": auto_parallel_context().set_comm_fusion}

@@ -1042,6 +1089,7 @@ _get_auto_parallel_context_func_map = {
     "sharding_propagation": auto_parallel_context().get_sharding_propagation,
     "enable_alltoall": auto_parallel_context().get_enable_alltoall,
     "comm_fusion": auto_parallel_context().get_comm_fusion,
+    "strategy_ckpt_config": auto_parallel_context().get_strategy_ckpt_config(),
     "full_batch_is_set": auto_parallel_context().get_full_batch_is_set}

@@ -1051,7 +1099,8 @@ _get_auto_parallel_context_func_map = {
                  strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
                  grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
                  communi_parallel_mode=str, optimizer_weight_shard_size=int, sharding_propagation=bool,
-                 optimizer_weight_shard_aggregated_save=bool, enable_alltoall=bool, comm_fusion=dict)
+                 optimizer_weight_shard_aggregated_save=bool, enable_alltoall=bool, comm_fusion=dict,
+                 strategy_ckpt_config=dict)

 def _set_auto_parallel_context(**kwargs):
     """
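The two map entries are what make the new keyword reachable from set_auto_parallel_context(**kwargs). A hedged, simplified sketch of the dispatch pattern those maps drive (not the exact MindSpore implementation; the error message is illustrative):

def _dispatch_parallel_context_sketch(func_map, **kwargs):
    # Route each keyword argument to its registered setter; unknown keywords are rejected.
    for key, value in kwargs.items():
        if key not in func_map:
            raise ValueError("Set context keyword '{}' is not recognized!".format(key))
        func_map[key](value)

# e.g. _dispatch_parallel_context_sketch(_set_auto_parallel_context_func_map,
#                                        strategy_ckpt_config={"save_file": "./s.ckpt"})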
@@ -75,7 +75,7 @@ def test_six_matmul_save():
            return out

    reset_auto_parallel_context()
-    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1.ckpt",
+    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_config={"save_file": "./strategy_stage1.ckpt"},
                              group_ckpt_save_file="./group_stage1.ckpt", dataset_strategy="full_batch")
    strategy1 = ((8, 1), (1, 1))
    strategy2 = ((1, 8), (8, 1))
@@ -137,7 +137,7 @@ def six_matmul_load():
            return out

    reset_auto_parallel_context()
-    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1.ckpt",
+    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_config={"load_file": "./strategy_stage1.ckpt"},
                              group_ckpt_save_file="./group_stage1.ckpt", dataset_strategy="full_batch")
    strategy1 = ((8, 1), (1, 1))
    strategy3 = ((8, 1), (1, 1))
@@ -183,7 +183,7 @@ def test_six_matmul_save_auto():
            self.matmul4 = P.MatMul()
            self.matmul5 = P.MatMul()
            self.matmul6 = P.MatMul()
-            self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
+            self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1", requires_grad=False)
            self.weight2 = Parameter(Tensor(np.ones([64, 64]), dtype=ms.float32), name="weight2")
            self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
            self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
@@ -201,7 +201,9 @@ def test_six_matmul_save_auto():
        return out

    reset_auto_parallel_context()
-    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1_auto.json")
+    set_auto_parallel_context(device_num=8, global_rank=0,
+                              strategy_ckpt_config={"save_file": "./strategy_stage1_auto.json",
+                                                    "only_trainable_params": True})
    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel", dataset_strategy="full_batch")
    x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
@@ -256,7 +258,8 @@ def six_matmul_load_auto():
        return out

    reset_auto_parallel_context()
-    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1_auto.json")
+    set_auto_parallel_context(device_num=8, global_rank=0,
+                              strategy_ckpt_config={"load_file": "./strategy_stage1_auto.json"})
    strategy1 = ((2, 2), (2, 2))
    strategy3 = ((2, 2), (2, 2))
    strategy4 = ((2, 2), (2, 2))
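One detail worth noting in the test hunks: weight1 is now created with requires_grad=False, so with "only_trainable_params": True its strategy should be omitted from the saved file, which is exactly what the C++ filter in the first hunk implements. A hedged way to inspect what a saved strategy file actually contains, assuming mindspore.build_searched_strategy is available and the file is in protobuf format (the path is a placeholder):

import mindspore as ms

# Load a previously saved parallel-strategy checkpoint and list the parameters it covers.
# With strategy_ckpt_config={"only_trainable_params": True}, frozen parameters such as a
# requires_grad=False weight are expected to be absent from this mapping.
strategies = ms.build_searched_strategy("./strategy_stage1.ckpt")
print(sorted(strategies.keys()))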