diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
index f6dd9072c50..6fd84dd3641 100644
--- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
@@ -343,7 +343,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
   std::string strategy_key_name = "";
   auto param_names = NodeParameterName(cnode);
   if (!param_names.empty()) {
-    strategy_key_name = param_names[0].first;
+    strategy_key_name = prim->name() + "_" + param_names[0].first;
   }
   bool load_strategy_from_ckpt =
     StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end();
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index 2b495ec2810..3ccfbe3c26b 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -1528,7 +1528,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
     std::string strategy_key_name = "";
     auto param_names = NodeParameterName(cnode);
     if (!param_names.empty()) {
-      strategy_key_name = param_names[0].first;
+      strategy_key_name = prim->name() + "_" + param_names[0].first;
     }
     bool load_strategy_from_ckpt =
       StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
@@ -2219,9 +2219,23 @@ std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node)
     auto input = node_inputs[i];
     if (input->isa<Parameter>()) {
       auto input_parameter = input->cast<ParameterPtr>();
-      if (input_parameter->has_default()) {
-        if (ParameterRequireGrad(input_parameter)) {
-          param_names.push_back({input_parameter->name(), i});
+      if (input_parameter->has_default() && ParameterRequireGrad(input_parameter)) {
+        param_names.push_back({input_parameter->name(), i});
+      }
+    } else if (input->isa<CNode>()) {
+      CNodePtr cnode = input->cast<CNodePtr>();
+      if (!IsValueNode<Primitive>(cnode->input(0))) {
+        return param_names;
+      }
+      ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
+      PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
+      if (prim->name() == CAST && cnode->inputs().size() >= 1) {
+        auto cast_input = cnode->inputs()[1];
+        if (cast_input->isa<Parameter>()) {
+          auto cast_input_parameter = cast_input->cast<ParameterPtr>();
+          if (cast_input_parameter->has_default() && ParameterRequireGrad(cast_input_parameter)) {
+            param_names.push_back({cast_input_parameter->name(), i});
+          }
         }
       }
     }
@@ -2229,14 +2243,11 @@ std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node)
   return param_names;
 }
 
-void CheckpointStrategy(const FuncGraphPtr &func_graph) {
-  MS_EXCEPTION_IF_NULL(func_graph);
+void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
   MS_LOG(DEBUG) << "Save strategy to checkpoint begin";
   StrategyMap stra_map;
   TensorInfoMap tensor_info_map;
   ManualShapeMap manual_shape_map;
-  auto ret = func_graph->get_return();
-  auto all_nodes = DeepScopedGraphSearch(ret);
   for (auto &node : all_nodes) {
     MS_EXCEPTION_IF_NULL(node);
     auto cnode = node->cast<CNodePtr>();
@@ -2258,7 +2269,8 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
       std::vector<TensorInfo> input_tensor_info = operator_info->inputs_tensor_info();
      StrategyPtr strategyPtr = operator_info->strategy();
       MS_EXCEPTION_IF_NULL(node->scope());
-      stra_map[param_name] = strategyPtr;
+      std::string stratey_key_name = prim->name() + "_" + param_name;
+      stra_map[stratey_key_name] = strategyPtr;
       for (auto param_name_pair : param_names) {
         if (param_name_pair.second - 1 >= UintToInt(input_tensor_info.size())) {
           continue;
@@ -2552,7 +2564,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
 
     // save strategy as checkpoint for multi-train
     if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) {
-      CheckpointStrategy(root);
+      CheckpointStrategy(all_nodes);
     }
 
     HandleSymbolicKeyInstance(root, all_nodes);
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.h b/mindspore/ccsrc/frontend/parallel/step_parallel.h
index ece29968711..a90d740d58c 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.h
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.h
@@ -136,7 +136,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodeP
 
 std::vector<std::pair<std::string, int>> NodeParameterName(const CNodePtr &node);
 
-void CheckpointStrategy(const FuncGraphPtr &func_graph);
+void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes);
 
 // main step of Parallel
 bool StepParallel(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer);
diff --git a/model_zoo/official/recommend/wide_and_deep/README.md b/model_zoo/official/recommend/wide_and_deep/README.md
index 1cc86e29bce..06084288855 100644
--- a/model_zoo/official/recommend/wide_and_deep/README.md
+++ b/model_zoo/official/recommend/wide_and_deep/README.md
@@ -152,7 +152,11 @@ optional arguments:
   --keep_prob The keep rate in dropout layer.(Default:1.0)
   --dropout_flag Enable dropout.(Default:0)
   --output_path Deprecated
-  --ckpt_path The location of the checkpoint file.(Defalut:./checkpoints/)
+  --ckpt_path The location of the checkpoint file. If the checkpoint file
+              is a slice of weight, multiple checkpoint files need to be
+              transferred. Use ';' to separate them and sort them in sequence
+              like "./checkpoints/0.ckpt;./checkpoints/1.ckpt".
+              (Defalut:./checkpoints/)
   --eval_file_name Eval output file.(Default:eval.og)
   --loss_file_name Loss output file.(Default:loss.log)
   --host_device_mix Enable host device mode or not.(Default:0)
diff --git a/model_zoo/official/recommend/wide_and_deep/eval.py b/model_zoo/official/recommend/wide_and_deep/eval.py
index 7f664f8abad..a139bbfacf5 100644
--- a/model_zoo/official/recommend/wide_and_deep/eval.py
+++ b/model_zoo/official/recommend/wide_and_deep/eval.py
@@ -18,7 +18,8 @@
 import os
 
 from mindspore import Model, context
-from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.train.serialization import load_checkpoint, load_param_into_net,\
+    build_searched_strategy, merge_sliced_parameter
 
 from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
 from src.callbacks import LossCallBack, EvalCallBack
@@ -81,8 +82,28 @@ def test_eval(config):
     net_builder = ModelBuilder()
 
     train_net, eval_net = net_builder.get_net(config)
-
-    param_dict = load_checkpoint(config.ckpt_path)
+    ckpt_path = config.ckpt_path
+    if ";" in ckpt_path:
+        ckpt_paths = ckpt_path.split(';')
+        param_list_dict = {}
+        strategy = build_searched_strategy(config.stra_ckpt)
+        for slice_path in ckpt_paths:
+            param_slice_dict = load_checkpoint(slice_path)
+            for key, value in param_slice_dict.items():
+                if 'optimizer' in key:
+                    continue
+                if key not in param_list_dict:
+                    param_list_dict[key] = []
+                param_list_dict[key].append(value)
+        param_dict = {}
+        for key, value in param_list_dict.items():
+            if key in strategy:
+                merged_parameter = merge_sliced_parameter(value, strategy)
+            else:
+                merged_parameter = merge_sliced_parameter(value)
+            param_dict[key] = merged_parameter
+    else:
+        param_dict = load_checkpoint(ckpt_path)
     load_param_into_net(eval_net, param_dict)
 
     auc_metric = AUCMetric()
diff --git a/model_zoo/official/recommend/wide_and_deep/src/callbacks.py b/model_zoo/official/recommend/wide_and_deep/src/callbacks.py
index 093307dfc73..40dba54578c 100644
--- a/model_zoo/official/recommend/wide_and_deep/src/callbacks.py
+++ b/model_zoo/official/recommend/wide_and_deep/src/callbacks.py
@@ -97,6 +97,7 @@ class EvalCallBack(Callback):
         self.eval_file_name = config.eval_file_name
         self.eval_values = []
         self.host_device_mix = host_device_mix
+        self.config = config
 
     def epoch_end(self, run_context):
         """
@@ -106,7 +107,7 @@ class EvalCallBack(Callback):
         parallel_mode = context.get_auto_parallel_context("parallel_mode")
         if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
             context.set_auto_parallel_context(strategy_ckpt_save_file="",
-                                              strategy_ckpt_load_file="./strategy_train.ckpt")
+                                              strategy_ckpt_load_file=self.config.stra_ckpt)
         rank_id = 0
         if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL,
                              ParallelMode.DATA_PARALLEL):
diff --git a/model_zoo/official/recommend/wide_and_deep/src/config.py b/model_zoo/official/recommend/wide_and_deep/src/config.py
index a7d1035a10b..5464110cb1a 100644
--- a/model_zoo/official/recommend/wide_and_deep/src/config.py
+++ b/model_zoo/official/recommend/wide_and_deep/src/config.py
@@ -39,6 +39,8 @@ def argparse_init():
     parser.add_argument("--dropout_flag", type=int, default=0, help="Enable dropout")
     parser.add_argument("--output_path", type=str, default="./output/")
     parser.add_argument("--ckpt_path", type=str, default="./checkpoints/", help="The location of the checkpoint file.")
+    parser.add_argument("--stra_ckpt", type=str, default="./checkpoints/strategy.ckpt",
+                        help="The strategy checkpoint file.")
     parser.add_argument("--eval_file_name", type=str, default="eval.log", help="Eval output file.")
     parser.add_argument("--loss_file_name", type=str, default="loss.log", help="Loss output file.")
     parser.add_argument("--host_device_mix", type=int, default=0, help="Enable host device mode or not")
@@ -75,6 +77,7 @@ class WideDeepConfig():
         self.eval_file_name = "eval.log"
         self.loss_file_name = "loss.log"
         self.ckpt_path = "./checkpoints/"
+        self.stra_ckpt = './checkpoints/strategy.ckpt'
         self.host_device_mix = 0
         self.dataset_type = "tfrecord"
         self.parameter_server = 0
@@ -107,6 +110,7 @@ class WideDeepConfig():
         self.eval_file_name = args.eval_file_name
         self.loss_file_name = args.loss_file_name
         self.ckpt_path = args.ckpt_path
+        self.stra_ckpt = args.stra_ckpt
         self.host_device_mix = args.host_device_mix
         self.dataset_type = args.dataset_type
         self.parameter_server = args.parameter_server
diff --git a/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py
index 8010a843e7c..ae2b3f9ece8 100644
--- a/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py
+++ b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py
@@ -203,6 +203,7 @@ class WideDeepModel(nn.Cell):
             self.dense_layer_1.dropout.dropout_do_mask.set_strategy(((1, get_group_size()),))
             self.dense_layer_1.dropout.dropout.set_strategy(((1, get_group_size()),))
             self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1)))
+            self.dense_layer_1.matmul.add_prim_attr("field_size", config.field_size)
             self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim,
                                                            slice_mode=nn.EmbeddingLookUpSplitMode.TABLE_COLUMN_SLICE)
             self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1,
@@ -211,6 +212,10 @@ class WideDeepModel(nn.Cell):
             self.deep_reshape.add_prim_attr("skip_redistribution", True)
             self.reduce_sum.add_prim_attr("cross_batch", True)
             self.embedding_table = self.deep_embeddinglookup.embedding_table
+        elif host_device_mix:
+            self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim)
+            self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1)
+            self.embedding_table = self.deep_embeddinglookup.embedding_table
         elif parameter_server:
             self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim)
             self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1)
diff --git a/model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py b/model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py
index a47b9e040e1..2ef0cf090f9 100644
--- a/model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py
+++ b/model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py
@@ -111,10 +111,11 @@ def train_and_eval(config):
     eval_callback = EvalCallBack(model, ds_eval, auc_metric, config, host_device_mix=host_device_mix)
 
     callback = LossCallBack(config=config, per_print_times=20)
-    ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
+    ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs,
+                                  keep_checkpoint_max=5, integrated_save=False)
     ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                                  directory=config.ckpt_path, config=ckptconfig)
-    context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_train.ckpt")
+    context.set_auto_parallel_context(strategy_ckpt_save_file=config.stra_ckpt)
     callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
     if not host_device_mix:
         callback_list.append(ckpoint_cb)
diff --git a/tests/st/model_zoo_tests/wide_and_deep/python_file_for_ci/config.py b/tests/st/model_zoo_tests/wide_and_deep/python_file_for_ci/config.py
index 578a4f93e33..4a2ed56e533 100644
--- a/tests/st/model_zoo_tests/wide_and_deep/python_file_for_ci/config.py
+++ b/tests/st/model_zoo_tests/wide_and_deep/python_file_for_ci/config.py
@@ -30,6 +30,8 @@ def argparse_init():
     parser.add_argument("--deep_layer_dim", type=int, nargs='+', default=[1024, 512, 256, 128])
     parser.add_argument("--deep_layer_act", type=str, default='relu')
     parser.add_argument("--keep_prob", type=float, default=1.0)
+    parser.add_argument("--stra_ckpt", type=str, default="./strategy_train.ckpt",
+                        help="The strategy checkpoint file.")
 
     parser.add_argument("--output_path", type=str, default="./output/")
     parser.add_argument("--ckpt_path", type=str, default="./checkpoints/")
@@ -63,6 +65,7 @@ class WideDeepConfig():
         self.eval_file_name = "eval.log"
         self.loss_file_name = "loss.log"
         self.ckpt_path = "./checkpoints/"
+        self.stra_ckpt = "./strategy_train.ckpt"
 
     def argparse_init(self):
         """
@@ -90,3 +93,4 @@ class WideDeepConfig():
         self.eval_file_name = args.eval_file_name
         self.loss_file_name = args.loss_file_name
         self.ckpt_path = args.ckpt_path
+        self.stra_ckpt = args.stra_ckpt