From 08406743968a539408789f876ba250c43f6acb47 Mon Sep 17 00:00:00 2001
From: yao_yf
Date: Wed, 22 Feb 2023 16:40:21 +0800
Subject: [PATCH] strategy file consider parameters trainable or not

---
 .../frontend/parallel/step_parallel_utils.cc  |  3 +-
 .../include/common/utils/parallel_context.h   |  4 ++
 mindspore/ccsrc/pipeline/jit/init.cc          |  4 ++
 mindspore/ccsrc/utils/parallel_context.cc     |  4 ++
 mindspore/python/mindspore/context.py         | 27 ++++++++--
 .../parallel/_auto_parallel_context.py        | 51 ++++++++++++++++++-
 .../parallel/test_strategy_checkpoint.py      | 13 +++--
 7 files changed, 96 insertions(+), 10 deletions(-)

diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel_utils.cc b/mindspore/ccsrc/frontend/parallel/step_parallel_utils.cc
index ccfc9d1541a..e0b1879f0e5 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel_utils.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel_utils.cc
@@ -1287,6 +1287,7 @@ ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_
                   << MAX_RECURSIVE_DEPTH;
     return {};
   }
+  bool only_trainable_params = ParallelContext::GetInstance()->stra_file_only_trainable_params();
   std::vector<AnfNodePtr> node_inputs{node->inputs()};
   ParameterMap param_names;
   for (int64_t i = 0; i < UlongToLong(node_inputs.size()); ++i) {
@@ -1294,7 +1295,7 @@ ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_
     auto input = node_inputs[LongToSize(i)];
     if (input->isa<Parameter>()) {
       auto input_parameter = input->cast<ParameterPtr>();
-      if (input_parameter->has_default()) {
+      if (input_parameter->has_default() && (!only_trainable_params || ParameterRequireGrad(input_parameter))) {
         (void)param_names.emplace_back(std::make_pair(input_parameter->name(), input_parameter));
       }
     } else if (input->isa<CNode>()) {
diff --git a/mindspore/ccsrc/include/common/utils/parallel_context.h b/mindspore/ccsrc/include/common/utils/parallel_context.h
index 72dbbc68a53..6e50a98f869 100644
--- a/mindspore/ccsrc/include/common/utils/parallel_context.h
+++ b/mindspore/ccsrc/include/common/utils/parallel_context.h
@@ -192,6 +192,9 @@ class COMMON_EXPORT ParallelContext {
   void set_do_transform(const bool);
   bool do_transform() const { return do_transform_; }
 
+  void set_stra_file_only_trainable_params(const bool);
+  bool stra_file_only_trainable_params() const { return stra_file_only_trainable_params_; }
+
  private:
   ParallelContext();
   bool ParallelContextCareGraph(const FuncGraphPtr &func_graph) const;
@@ -239,6 +242,7 @@ class COMMON_EXPORT ParallelContext {
   bool sharding_propagation_;
   bool enable_micro_interleaved_ = false;
   bool do_transform_ = false;
+  bool stra_file_only_trainable_params_ = true;
   std::string fusion_mode_;
 };
 }  // namespace mindspore::parallel
diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc
index a87b5e8e98f..99f154781c1 100644
--- a/mindspore/ccsrc/pipeline/jit/init.cc
+++ b/mindspore/ccsrc/pipeline/jit/init.cc
@@ -305,6 +305,10 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("get_full_batch_is_set", &ParallelContext::full_batch_is_set, "Get whether attr full_batch is set.")
     .def("set_dataset_strategy", &ParallelContext::set_dataset_strategy, "Set dataset sharding strategy.")
     .def("get_dataset_strategy", &ParallelContext::dataset_strategy, "Get dataset sharding strategy.")
+    .def("set_stra_file_only_trainable_params", &ParallelContext::set_stra_file_only_trainable_params,
+         "Set whether the strategy ckpt only saves trainable params.")
+    .def("get_stra_file_only_trainable_params", &ParallelContext::stra_file_only_trainable_params,
+         "Get whether the strategy ckpt only saves trainable params.")
trainable params.") .def("set_enable_parallel_optimizer", &ParallelContext::set_enable_parallel_optimizer, "Set enable/disable parallel optimizer.") .def("get_enable_parallel_optimizer", &ParallelContext::enable_parallel_optimizer, diff --git a/mindspore/ccsrc/utils/parallel_context.cc b/mindspore/ccsrc/utils/parallel_context.cc index 6390a37e026..8af8e29cacd 100644 --- a/mindspore/ccsrc/utils/parallel_context.cc +++ b/mindspore/ccsrc/utils/parallel_context.cc @@ -283,5 +283,9 @@ void ParallelContext::set_pipeline_micro_size(const size_t pipeline_micro_size) void ParallelContext::set_do_transform(const bool do_transform) { do_transform_ = do_transform; } +void ParallelContext::set_stra_file_only_trainable_params(const bool stra_file_only_trainable_params) { + stra_file_only_trainable_params_ = stra_file_only_trainable_params; +} + void ParallelContext::set_sharding_propagation(const bool stra_pto) { sharding_propagation_ = stra_pto; } } // namespace mindspore::parallel diff --git a/mindspore/python/mindspore/context.py b/mindspore/python/mindspore/context.py index dc95a93a1df..25849c5b87f 100644 --- a/mindspore/python/mindspore/context.py +++ b/mindspore/python/mindspore/context.py @@ -521,7 +521,7 @@ def _context(): auto_parallel_search_mode=str, search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str, strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool, enable_alltoall=bool, all_reduce_fusion_config=list, pipeline_stages=int, grad_accumulation_step=int, - parallel_optimizer_config=dict, comm_fusion=dict) + parallel_optimizer_config=dict, comm_fusion=dict, strategy_ckpt_config=dict) def set_auto_parallel_context(**kwargs): r""" Set auto parallel context, only data parallel supported on CPU. @@ -549,6 +549,7 @@ def set_auto_parallel_context(**kwargs): enable_alltoall grad_accumulation_step \ auto_parallel_search_mode \ comm_fusion + \ strategy_ckpt_config =========================== =========================== Args: @@ -587,8 +588,10 @@ def set_auto_parallel_context(**kwargs): data_parallel mode, all parameters are broadcast except for the parameter whose attribute layerwise_parallel is True. Hybrid_parallel, semi_auto_parallel and auto_parallel mode, the segmented parameters do not participate in broadcasting. Default: False. - strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: '' - strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: '' + strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. The interface is not to be + recommended currently, it is better using 'strategy_ckpt_config' to replace it. Default: '' + strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. The interface is not to be + recommended currently, it is better using 'strategy_ckpt_config' to replace it. Default: '' full_batch (bool): If you load whole batch datasets in auto_parallel mode, this parameter should be set as True. Default: False. The interface is not to be recommended currently, it is better using 'dataset_strategy' to replace it. @@ -654,6 +657,24 @@ def set_auto_parallel_context(**kwargs): - reducescatter: If communication fusion type is `reducescatter`. The `mode` contains: `auto` and `size`. Config is same as `allgather`. + strategy_ckpt_config (dict): A dict contains the configurations for setting the parallel strategy file. 
+        strategy_ckpt_config (dict): A dict that contains the configurations for setting the parallel strategy
+            file. This interface covers the functions of the `strategy_ckpt_load_file` and
+            `strategy_ckpt_save_file` interfaces; it is recommended to use this interface to replace
+            those two interfaces.
+            It contains the following configurations:
+
+            - load_file (str): The path to load the parallel strategy checkpoint. If the file name extension is
+              `.json`, the file is loaded in JSON format. Otherwise, the file is loaded in protobuf
+              format.
+              Default: ''
+
+            - save_file (str): The path to save the parallel strategy checkpoint. If the file name extension is
+              `.json`, the file is saved in JSON format. Otherwise, the file is saved in protobuf format.
+              Default: ''
+
+            - only_trainable_params (bool): Whether to save/load the strategy information only for trainable
+              parameters. Default: True.
+
     Raises:
         ValueError: If input key is not attribute in auto parallel context.

diff --git a/mindspore/python/mindspore/parallel/_auto_parallel_context.py b/mindspore/python/mindspore/parallel/_auto_parallel_context.py
index bd13248c1e3..f2bc7e5f43f 100644
--- a/mindspore/python/mindspore/parallel/_auto_parallel_context.py
+++ b/mindspore/python/mindspore/parallel/_auto_parallel_context.py
@@ -509,6 +509,52 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_strategy_ckpt_save_file()

+    def set_strategy_ckpt_config(self, strategy_ckpt_config):
+        """
+        Set strategy checkpoint config.
+
+        Args:
+            strategy_ckpt_config (dict): The strategy checkpoint config.
+        """
+        self.check_context_handle()
+        if not isinstance(strategy_ckpt_config, dict):
+            raise TypeError("For 'set_auto_parallel_context', the argument 'strategy_ckpt_config' "
+                            "must be dict, but got the type : {}.".format(type(strategy_ckpt_config)))
+        unknown_config = []
+        for config_name in strategy_ckpt_config:
+            if config_name not in ["load_file", "save_file", "only_trainable_params"]:
+                unknown_config.append(config_name)
+
+        if unknown_config:
+            raise ValueError("Unknown config: {}".format(unknown_config))
+        if "load_file" in strategy_ckpt_config:
+            load_file = strategy_ckpt_config.get("load_file")
+            if not isinstance(load_file, str):
+                raise TypeError("For 'set_auto_parallel_context().set_strategy_ckpt_config', "
+                                "the argument 'load_file' must be str, but got the type : {} .".format(type(load_file)))
+            self._context_handle.set_strategy_ckpt_load_file(load_file)
+        if "save_file" in strategy_ckpt_config:
+            save_file = strategy_ckpt_config.get("save_file")
+            if not isinstance(save_file, str):
+                raise TypeError("For 'set_auto_parallel_context().set_strategy_ckpt_config', "
+                                "the argument 'save_file' must be str, but got the type : {} .".format(type(save_file)))
+            self._context_handle.set_strategy_ckpt_save_file(save_file)
+        if "only_trainable_params" in strategy_ckpt_config:
+            only_trainable_params = strategy_ckpt_config.get("only_trainable_params")
+            if not isinstance(only_trainable_params, bool):
+                raise TypeError("For 'set_auto_parallel_context().set_strategy_ckpt_config', "
+                                "the argument 'only_trainable_params' must be bool,"
+                                " but got the type : {} .".format(type(only_trainable_params)))
+            self._context_handle.set_stra_file_only_trainable_params(only_trainable_params)
+
+    def get_strategy_ckpt_config(self):
+        """Get strategy checkpoint config."""
+        self.check_context_handle()
+        load_file = self._context_handle.get_strategy_ckpt_load_file()
+        save_file = self._context_handle.get_strategy_ckpt_save_file()
+        only_trainable_param = self._context_handle.get_stra_file_only_trainable_params()
+        return {"load_file": load_file, "save_file": save_file, "only_trainable_params": only_trainable_param}
save_file, "only_trainable_params": only_trainable_param} + def set_group_ckpt_save_file(self, group_ckpt_save_file): """Set group checkpoint save path.""" self.check_context_handle() @@ -1015,6 +1061,7 @@ _set_auto_parallel_context_func_map = { "optimizer_weight_shard_aggregated_save": auto_parallel_context().set_optimizer_weight_shard_aggregated_save, "sharding_propagation": auto_parallel_context().set_sharding_propagation, "enable_alltoall": auto_parallel_context().set_enable_alltoall, + "strategy_ckpt_config": auto_parallel_context().set_strategy_ckpt_config, "comm_fusion": auto_parallel_context().set_comm_fusion} @@ -1042,6 +1089,7 @@ _get_auto_parallel_context_func_map = { "sharding_propagation": auto_parallel_context().get_sharding_propagation, "enable_alltoall": auto_parallel_context().get_enable_alltoall, "comm_fusion": auto_parallel_context().get_comm_fusion, + "strategy_ckpt_config": auto_parallel_context().get_strategy_ckpt_config(), "full_batch_is_set": auto_parallel_context().get_full_batch_is_set} @@ -1051,7 +1099,8 @@ _get_auto_parallel_context_func_map = { strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool, grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str, communi_parallel_mode=str, optimizer_weight_shard_size=int, sharding_propagation=bool, - optimizer_weight_shard_aggregated_save=bool, enable_alltoall=bool, comm_fusion=dict) + optimizer_weight_shard_aggregated_save=bool, enable_alltoall=bool, comm_fusion=dict, + strategy_ckpt_config=dict) def _set_auto_parallel_context(**kwargs): """ diff --git a/tests/ut/python/parallel/test_strategy_checkpoint.py b/tests/ut/python/parallel/test_strategy_checkpoint.py index 49a319b4c67..a5402ec4601 100644 --- a/tests/ut/python/parallel/test_strategy_checkpoint.py +++ b/tests/ut/python/parallel/test_strategy_checkpoint.py @@ -75,7 +75,7 @@ def test_six_matmul_save(): return out reset_auto_parallel_context() - set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1.ckpt", + set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_config={"save_file": "./strategy_stage1.ckpt"}, group_ckpt_save_file="./group_stage1.ckpt", dataset_strategy="full_batch") strategy1 = ((8, 1), (1, 1)) strategy2 = ((1, 8), (8, 1)) @@ -137,7 +137,7 @@ def six_matmul_load(): return out reset_auto_parallel_context() - set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1.ckpt", + set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_config={"load_file": "./strategy_stage1.ckpt"}, group_ckpt_save_file="./group_stage1.ckpt", dataset_strategy="full_batch") strategy1 = ((8, 1), (1, 1)) strategy3 = ((8, 1), (1, 1)) @@ -183,7 +183,7 @@ def test_six_matmul_save_auto(): self.matmul4 = P.MatMul() self.matmul5 = P.MatMul() self.matmul6 = P.MatMul() - self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1") + self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1", requires_grad=False) self.weight2 = Parameter(Tensor(np.ones([64, 64]), dtype=ms.float32), name="weight2") self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3") self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4") @@ -201,7 +201,9 @@ def test_six_matmul_save_auto(): return out reset_auto_parallel_context() - set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1_auto.json") + 
+    set_auto_parallel_context(device_num=8, global_rank=0,
+                              strategy_ckpt_config={"save_file": "./strategy_stage1_auto.json",
+                                                    "only_trainable_params": True})
     net = GradWrap(NetWithLoss(Net()))
     context.set_auto_parallel_context(parallel_mode="auto_parallel", dataset_strategy="full_batch")
     x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
@@ -256,7 +258,8 @@ def six_matmul_load_auto():
             return out

     reset_auto_parallel_context()
-    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1_auto.json")
+    set_auto_parallel_context(device_num=8, global_rank=0,
+                              strategy_ckpt_config={"load_file": "./strategy_stage1_auto.json"})
     strategy1 = ((2, 2), (2, 2))
     strategy3 = ((2, 2), (2, 2))
     strategy4 = ((2, 2), (2, 2))
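
For reference, a minimal usage sketch of the strategy_ckpt_config option introduced by this patch. The file path
is illustrative, the calls mirror the updated tests above, and reading the value back assumes
get_auto_parallel_context() dispatches through the new "strategy_ckpt_config" entry added to the getter map:

    from mindspore import context

    # Save the parallel strategy file. With only_trainable_params=True (the new default), parameters
    # created with requires_grad=False are left out of the saved strategy information.
    context.set_auto_parallel_context(device_num=8, global_rank=0,
                                      strategy_ckpt_config={"save_file": "./strategy.ckpt",
                                                            "only_trainable_params": True})

    # Later, load the same strategy file back.
    context.set_auto_parallel_context(strategy_ckpt_config={"load_file": "./strategy.ckpt"})

    # Read the current configuration back as a dict:
    # {"load_file": "./strategy.ckpt", "save_file": "./strategy.ckpt", "only_trainable_params": True}
    print(context.get_auto_parallel_context("strategy_ckpt_config"))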