!18853 [AutoParallel] Pipeline support predict

Merge pull request !18853 from lichen/pipeline_support_predict_master
commit 2384cadeb9 by i-robot, 2021-06-26 08:02:59 +00:00 (committed by Gitee)
4 changed files with 52 additions and 25 deletions
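In inference there is no backward pass and no micro-batch split, so the changes below gate the pipeline passes on the root graph's TRAINING flag. The net effect, exercised by the new tests at the end of this diff, is that a pipeline-parallel net can be run through Model.predict. A minimal usage sketch in the spirit of those tests; the two-stage net here is a stand-in for the PipelineSplit2 fixture, whose definition is not part of this diff:

import numpy as np
import mindspore as ms
from mindspore import context, nn, ops, Tensor, Parameter
from mindspore.train.model import Model

class MatMulCell(nn.Cell):
    """One pipeline stage: a sharded MatMul against a local weight."""
    def __init__(self, strategy, shape):
        super().__init__()
        self.param = Parameter(Tensor(np.ones(shape), ms.float32), name="param")
        self.matmul = ops.MatMul().shard(strategy)

    def construct(self, x):
        return self.matmul(x, self.param)

class TwoStageNet(nn.Cell):
    """Stand-in for the PipelineSplit2 test fixture: two cells, two stages."""
    def __init__(self, strategy1, strategy2):
        super().__init__()
        self.block = nn.CellList([MatMulCell(strategy1, [64, 64]),
                                  MatMulCell(strategy2, [64, 64])])
        for i in range(2):
            self.block[i].pipeline_stage = i  # assign each cell to a stage

    def construct(self, x, label):
        # label is accepted only to mirror the test call signature
        for i in range(2):
            x = self.block[i](x)
        return x

context.set_auto_parallel_context(device_num=8, global_rank=0,
                                  pipeline_stages=2, full_batch=True)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
data = Tensor(np.ones([32, 64]), ms.float32)
label = Tensor(np.ones([64, 64]), ms.float32)
model = Model(TwoStageNet(((4, 1), (1, 1)), ((2, 1), (1, 1))))
model.predict(data, label)  # no micro-batch wrapper needed on the predict path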


@@ -349,10 +349,14 @@ PipelinePair Deduplicate(const std::vector<AnfNodePtr> &node_vector, const FuncG
   auto manager = root->manager();
   for (int64_t i = 0; i <= micro_max; ++i) {
     temp_vec.clear();
-    for (auto &node : node_vector) {
-      auto node_micro = GetMicroBatch(node);
-      if (node_micro == i) {
-        temp_vec.push_back(node);
+    if (!root->has_flag(TRAINING)) {
+      temp_vec = node_vector;
+    } else {
+      for (auto &node : node_vector) {
+        auto node_micro = GetMicroBatch(node);
+        if (node_micro == i) {
+          temp_vec.push_back(node);
+        }
       }
     }
     if (temp_vec.size() <= 1) {
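A Python paraphrase of the bucketing rule above (helper names are mine, not the C++ API): training groups the border nodes by micro-batch label, while predict has no labels, so the whole vector forms a single group; with micro_max staying 0 in predict (see the Reorder hunk below), the loop runs exactly once.

def group_by_micro(nodes, micro_max, training, micro_of):
    """Paraphrase of Deduplicate's bucketing: one bucket per micro batch
    when training, a single bucket holding every node when predicting."""
    groups = []
    for i in range(micro_max + 1):
        if not training:
            temp = list(nodes)  # predict: no micro labels, take everything
        else:
            temp = [n for n in nodes if micro_of(n) == i]
        groups.append(temp)
    return groups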
@@ -568,10 +572,13 @@ void Reorder(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
   std::vector<AnfNodePtr> allreduce_params;
   GetBorderNode(&forward_start, &forward_end, &backward_start, &backward_end, &forward_params, &backward_params,
                 &allreduce_params, root);
-  auto forward_end_cnode = forward_end.back()->cast<CNodePtr>();
-  auto micro_size = forward_end_cnode->GetPrimalAttr(MICRO);
-  MS_EXCEPTION_IF_NULL(micro_size);
-  auto micro_max = GetValue<int64_t>(micro_size);
+  int64_t micro_max = 0;
+  if (root->has_flag(TRAINING)) {
+    auto forward_end_cnode = forward_end.back()->cast<CNodePtr>();
+    auto micro_size = forward_end_cnode->GetPrimalAttr(MICRO);
+    MS_EXCEPTION_IF_NULL(micro_size);
+    micro_max = GetValue<int64_t>(micro_size);
+  }
   auto backward_start_pair = Deduplicate(backward_start, root, micro_max);
   auto backward_end_pair = Deduplicate(backward_end, root, micro_max);
   auto forward_start_pair = Deduplicate(forward_start, root, micro_max);
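The companion change in Reorder: micro_max is read from the last forward-end node's MICRO primal attribute only when training, and defaults to 0 for predict, which is what makes the single-group path above well defined. Paraphrased (helper names illustrative):

def micro_max_of(forward_end, training, micro_attr_of):
    """micro_max as Reorder now derives it: MICRO attr in training, 0 otherwise."""
    if not training:
        return 0
    micro = micro_attr_of(forward_end[-1])
    if micro is None:
        raise RuntimeError("forward-end node must carry MICRO when training")
    return micro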


@@ -87,6 +87,9 @@ ValuePtr PipelineTransformer::SetMicroBatch(const AnfNodePtr &node, int64_t micr
 }
 
 void PipelineTransformer::LabelMicroBatch() {
+  if (!root_->has_flag(TRAINING)) {
+    return;
+  }
   MS_EXCEPTION_IF_NULL(main_graph_);
   MS_EXCEPTION_IF_NULL(virtual_dataset_);
   auto node_user_map = manager_->node_users();
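The TRAINING flag these guards test originates on the Python side: a sketch of the user-visible switch, assuming the graph flag mirrors the Cell's training mode (Model.train compiles with set_train(True), predict/eval with set_train(False)):

import mindspore.nn as nn

net = nn.Dense(4, 4)
net.set_train(True)   # train: root graph carries TRAINING, micro batches get labeled
net.set_train(False)  # predict/eval: LabelMicroBatch returns early, as above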
@@ -516,7 +519,11 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
   std::vector<AnfNodePtr> depend_input = {NewValueNode(depend_op), parameter, send};
   auto depend = main_graph_->NewCNode(depend_input);
-  depend->set_abstract(parameter->abstract());
-  send->set_abstract(parameter->abstract());
+  auto abstract = parameter->abstract();
+  if (care_node) {
+    abstract = care_node->abstract();
+  }
+  depend->set_abstract(abstract);
+  send->set_abstract(abstract);
   SendAttr send_out = {shape_type_pair.first, shape_type_pair.second, depend};
   return send_out;
 }
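InsertSend now takes the Send/Depend abstract from care_node when one is recorded, rather than always from the forwarded parameter, so the shape and type presumably match the node the send actually serves. The selection rule, paraphrased:

def send_abstract(parameter_abstract, care_node_abstract=None):
    """Abstract attached to Send and its Depend: prefer the care node's."""
    if care_node_abstract is not None:
        return care_node_abstract
    return parameter_abstract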
@@ -630,7 +637,7 @@ std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer:
   std::vector<AnfNodePtr> send_ops;
   auto all_nodes = graph->nodes();
   auto stage_num = g_device_manager->stage_num();
-  if (stage_num > micro_size_) {
+  if (root_->has_flag(TRAINING) && (stage_num > micro_size_)) {
     MS_LOG(EXCEPTION) << "MicroBatch size: " << micro_size_ << " can't be less than stage num: " << stage_num;
   }
   for (auto &node : all_nodes) {
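The micro-batch check only makes sense when micro batches exist: pipeline training needs at least as many micro batches as stages to keep every stage busy, while predict runs a single pass. Paraphrased:

def check_micro_batch(stage_num, micro_size, training):
    """Training-only sanity check: enough micro batches to occupy every stage."""
    if training and stage_num > micro_size:
        raise ValueError(f"MicroBatch size: {micro_size} can't be less than "
                         f"stage num: {stage_num}")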
@@ -777,21 +784,7 @@ void PipelineTransformer::ElimParameter() {
   std::vector<AnfNodePtr> parameter_list;
   for (auto &parameter : parameters) {
     if (!manager_->node_users()[parameter].empty()) {
-      if (!root_->has_flag(TRAINING)) {
-        for (auto &node_pair : manager_->node_users()[parameter]) {
-          auto user_node = node_pair.first;
-          if (!IsPrimitiveCNode(user_node, prim::kPrimReceive)) {
-            parameter_list.push_back(parameter);
-            break;
-          }
-          // remove_receive_inputs
-          auto cnode = user_node->cast<CNodePtr>();
-          std::vector<AnfNodePtr> new_inputs = {cnode->input(0)};
-          cnode->set_inputs(new_inputs);
-        }
-      } else {
-        parameter_list.push_back(parameter);
-      }
+      parameter_list.push_back(parameter);
     }
   }
   auto del_num = parameters.size() - parameter_list.size();
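ElimParameter loses its predict-only branch: previously an inference parameter consumed only by Receive nodes was pruned and the Receive's extra inputs stripped; now training and predict share one rule, keep any parameter that still has users. Paraphrased:

def surviving_parameters(parameters, users_of):
    """Unified rule after this change: a parameter survives iff it has users."""
    return [p for p in parameters if len(users_of(p)) > 0]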


@@ -1621,7 +1621,8 @@ std::pair<AnfNodePtr, int64_t> FindParallelCareNode(const AnfNodePtr &node, int3
   MS_EXCEPTION_IF_NULL(prim_node_anf);
   PrimitivePtr node_prim = prim_node_anf->value()->cast<PrimitivePtr>();
   MS_EXCEPTION_IF_NULL(node_prim);
-  if ((node_prim->name() == DEPEND && node_pair.second != 1) || IsPrimitiveCNode(cnode, prim::kPrimReceive)) {
+  if ((node_prim->name() == DEPEND && node_pair.second != 1) || IsPrimitiveCNode(cnode, prim::kPrimReceive) ||
+      IsPrimitiveCNode(cnode, prim::kPrimSend)) {
     continue;
   }
   if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
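FindParallelCareNode walks a node's users looking for an operator that carries a parallel strategy. Send now joins Receive (and Depend's non-data edges) in the skip set, presumably because in predict a parameter's only user on a stage can be the Send to the next stage, which has no strategy of its own. The filter, paraphrased:

def is_worth_following(prim_name, input_index):
    """User-edge filter: skip Depend control edges and Send/Receive borders."""
    if prim_name == "Depend" and input_index != 1:
        return False  # only Depend's first input is a data edge
    return prim_name not in ("Receive", "Send")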


@@ -174,6 +174,31 @@ def test_pipeline_split_shared_parameter_stage1():
     model = Model(net, optimizer=optimizer)
     model.train(2, dataset, dataset_sink_mode=False)
 
+
+def test_pipeline_split_shared_parameter_stage0_predict():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2, full_batch=True)
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
+    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    strategy1 = ((4, 1), (1, 1))
+    strategy2 = ((2, 1), (1, 1))
+    net = PipelineSplit2(strategy1, strategy2)
+    model = Model(net)
+    model.predict(data, label)
+
+
+def test_pipeline_split_shared_parameter_stage1_predict():
+    context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2, full_batch=True)
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
+    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    strategy1 = ((4, 1), (1, 1))
+    strategy2 = ((2, 1), (1, 1))
+    net = PipelineSplit2(strategy1, strategy2)
+    model = Model(net)
+    model.predict(data, label)
+
+
 def test_pipeline_split_stage0_opt_shard():
     context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2, enable_parallel_optimizer=True)
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
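The two predict tests differ only in global_rank: with device_num=8 and pipeline_stages=2, the devices split into two stage groups of four, so rank 0 exercises the stage-0 program and rank 4 the stage-1 program. The mapping, for reference:

def stage_of(rank, device_num=8, pipeline_stages=2):
    """Stage owning a rank when devices are split evenly across stages."""
    return rank // (device_num // pipeline_stages)

assert stage_of(0) == 0  # test_..._stage0_predict
assert stage_of(4) == 1  # test_..._stage1_predict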
@@ -191,6 +216,7 @@ def test_pipeline_split_stage0_opt_shard():
     assert param.name != "cell.block.1.param"
     assert param.name != "cell.block.1.param1"
 
 
 def test_pipeline_split_stage1_opt_shard():
     context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2, enable_parallel_optimizer=True)
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")