diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc b/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc index 57f3d3ca7e4..bc9847dfdb9 100755 --- a/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc +++ b/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc @@ -349,10 +349,14 @@ PipelinePair Deduplicate(const std::vector &node_vector, const FuncG auto manager = root->manager(); for (int64_t i = 0; i <= micro_max; ++i) { temp_vec.clear(); - for (auto &node : node_vector) { - auto node_micro = GetMicroBatch(node); - if (node_micro == i) { - temp_vec.push_back(node); + if (!root->has_flag(TRAINING)) { + temp_vec = node_vector; + } else { + for (auto &node : node_vector) { + auto node_micro = GetMicroBatch(node); + if (node_micro == i) { + temp_vec.push_back(node); + } } } if (temp_vec.size() <= 1) { @@ -568,10 +572,13 @@ void Reorder(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) { std::vector allreduce_params; GetBorderNode(&forward_start, &forward_end, &backward_start, &backward_end, &forward_params, &backward_params, &allreduce_params, root); - auto forward_end_cnode = forward_end.back()->cast(); - auto micro_size = forward_end_cnode->GetPrimalAttr(MICRO); - MS_EXCEPTION_IF_NULL(micro_size); - auto micro_max = GetValue(micro_size); + int64_t micro_max = 0; + if (root->has_flag(TRAINING)) { + auto forward_end_cnode = forward_end.back()->cast(); + auto micro_size = forward_end_cnode->GetPrimalAttr(MICRO); + MS_EXCEPTION_IF_NULL(micro_size); + micro_max = GetValue(micro_size); + } auto backward_start_pair = Deduplicate(backward_start, root, micro_max); auto backward_end_pair = Deduplicate(backward_end, root, micro_max); auto forward_start_pair = Deduplicate(forward_start, root, micro_max); diff --git a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc index fd287f3e272..5b4a0c77f59 100644 --- a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc +++ b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc @@ -87,6 +87,9 @@ ValuePtr PipelineTransformer::SetMicroBatch(const AnfNodePtr &node, int64_t micr } void PipelineTransformer::LabelMicroBatch() { + if (!root_->has_flag(TRAINING)) { + return; + } MS_EXCEPTION_IF_NULL(main_graph_); MS_EXCEPTION_IF_NULL(virtual_dataset_); auto node_user_map = manager_->node_users(); @@ -516,7 +519,11 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod std::vector depend_input = {NewValueNode(depend_op), parameter, send}; auto depend = main_graph_->NewCNode(depend_input); auto abstract = parameter->abstract(); + if (care_node) { + abstract = care_node->abstract(); + } depend->set_abstract(abstract); + send->set_abstract(abstract); SendAttr send_out = {shape_type_pair.first, shape_type_pair.second, depend}; return send_out; } @@ -630,7 +637,7 @@ std::pair, std::vector> PipelineTransformer: std::vector send_ops; auto all_nodes = graph->nodes(); auto stage_num = g_device_manager->stage_num(); - if (stage_num > micro_size_) { + if (root_->has_flag(TRAINING) && (stage_num > micro_size_)) { MS_LOG(EXCEPTION) << "MicroBatch size: " << micro_size_ << " can't less than stage num: " << stage_num; } for (auto &node : all_nodes) { @@ -777,21 +784,7 @@ void PipelineTransformer::ElimParameter() { std::vector parameter_list; for (auto ¶meter : parameters) { if (!manager_->node_users()[parameter].empty()) { - if (!root_->has_flag(TRAINING)) { - for (auto &node_pair : manager_->node_users()[parameter]) { - auto user_node = node_pair.first; - if (!IsPrimitiveCNode(user_node, prim::kPrimReceive)) { - parameter_list.push_back(parameter); - break; - } - // remove_receive_inputs - auto cnode = user_node->cast(); - std::vector new_inputs = {cnode->input(0)}; - cnode->set_inputs(new_inputs); - } - } else { - parameter_list.push_back(parameter); - } + parameter_list.push_back(parameter); } } auto del_num = parameters.size() - parameter_list.size(); diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index ff2b05f4715..7a29f6e10ad 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -1621,7 +1621,8 @@ std::pair FindParallelCareNode(const AnfNodePtr &node, int3 MS_EXCEPTION_IF_NULL(prim_node_anf); PrimitivePtr node_prim = prim_node_anf->value()->cast(); MS_EXCEPTION_IF_NULL(node_prim); - if ((node_prim->name() == DEPEND && node_pair.second != 1) || IsPrimitiveCNode(cnode, prim::kPrimReceive)) { + if ((node_prim->name() == DEPEND && node_pair.second != 1) || IsPrimitiveCNode(cnode, prim::kPrimReceive) || + IsPrimitiveCNode(cnode, prim::kPrimSend)) { continue; } if (IsParallelCareNode(cnode) && cnode->has_user_data()) { diff --git a/tests/ut/python/parallel/test_pipeline_split.py b/tests/ut/python/parallel/test_pipeline_split.py index 9204d04286c..f2a6a2d32f7 100644 --- a/tests/ut/python/parallel/test_pipeline_split.py +++ b/tests/ut/python/parallel/test_pipeline_split.py @@ -174,6 +174,31 @@ def test_pipeline_split_shared_parameter_stage1(): model = Model(net, optimizer=optimizer) model.train(2, dataset, dataset_sink_mode=False) + +def test_pipeline_split_shared_parameter_stage0_predict(): + context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2, full_batch=True) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") + data = Tensor(np.ones([32, 64]), dtype=ms.float32) + label = Tensor(np.ones([64, 64]), dtype=ms.float32) + strategy1 = ((4, 1), (1, 1)) + strategy2 = ((2, 1), (1, 1)) + net = PipelineSplit2(strategy1, strategy2) + model = Model(net) + model.predict(data, label) + + +def test_pipeline_split_shared_parameter_stage1_predict(): + context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2, full_batch=True) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") + data = Tensor(np.ones([32, 64]), dtype=ms.float32) + label = Tensor(np.ones([64, 64]), dtype=ms.float32) + strategy1 = ((4, 1), (1, 1)) + strategy2 = ((2, 1), (1, 1)) + net = PipelineSplit2(strategy1, strategy2) + model = Model(net) + model.predict(data, label) + + def test_pipeline_split_stage0_opt_shard(): context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2, enable_parallel_optimizer=True) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") @@ -191,6 +216,7 @@ def test_pipeline_split_stage0_opt_shard(): assert param.name != "cell.block.1.param" assert param.name != "cell.block.1.param1" + def test_pipeline_split_stage1_opt_shard(): context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2, enable_parallel_optimizer=True) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")