forked from mindspore-Ecosystem/mindspore
!18853 [AutoParallel]Pipeline support predict
Merge pull request !18853 from lichen/pipeline_support_predict_master
This commit is contained in:
commit
2384cadeb9
|
@ -349,10 +349,14 @@ PipelinePair Deduplicate(const std::vector<AnfNodePtr> &node_vector, const FuncG
|
|||
auto manager = root->manager();
|
||||
for (int64_t i = 0; i <= micro_max; ++i) {
|
||||
temp_vec.clear();
|
||||
for (auto &node : node_vector) {
|
||||
auto node_micro = GetMicroBatch(node);
|
||||
if (node_micro == i) {
|
||||
temp_vec.push_back(node);
|
||||
if (!root->has_flag(TRAINING)) {
|
||||
temp_vec = node_vector;
|
||||
} else {
|
||||
for (auto &node : node_vector) {
|
||||
auto node_micro = GetMicroBatch(node);
|
||||
if (node_micro == i) {
|
||||
temp_vec.push_back(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (temp_vec.size() <= 1) {
|
||||
|
@ -568,10 +572,13 @@ void Reorder(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
|
|||
std::vector<AnfNodePtr> allreduce_params;
|
||||
GetBorderNode(&forward_start, &forward_end, &backward_start, &backward_end, &forward_params, &backward_params,
|
||||
&allreduce_params, root);
|
||||
auto forward_end_cnode = forward_end.back()->cast<CNodePtr>();
|
||||
auto micro_size = forward_end_cnode->GetPrimalAttr(MICRO);
|
||||
MS_EXCEPTION_IF_NULL(micro_size);
|
||||
auto micro_max = GetValue<int64_t>(micro_size);
|
||||
int64_t micro_max = 0;
|
||||
if (root->has_flag(TRAINING)) {
|
||||
auto forward_end_cnode = forward_end.back()->cast<CNodePtr>();
|
||||
auto micro_size = forward_end_cnode->GetPrimalAttr(MICRO);
|
||||
MS_EXCEPTION_IF_NULL(micro_size);
|
||||
micro_max = GetValue<int64_t>(micro_size);
|
||||
}
|
||||
auto backward_start_pair = Deduplicate(backward_start, root, micro_max);
|
||||
auto backward_end_pair = Deduplicate(backward_end, root, micro_max);
|
||||
auto forward_start_pair = Deduplicate(forward_start, root, micro_max);
|
||||
|
|
|
@ -87,6 +87,9 @@ ValuePtr PipelineTransformer::SetMicroBatch(const AnfNodePtr &node, int64_t micr
|
|||
}
|
||||
|
||||
void PipelineTransformer::LabelMicroBatch() {
|
||||
if (!root_->has_flag(TRAINING)) {
|
||||
return;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(main_graph_);
|
||||
MS_EXCEPTION_IF_NULL(virtual_dataset_);
|
||||
auto node_user_map = manager_->node_users();
|
||||
|
@ -516,7 +519,11 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
|
|||
std::vector<AnfNodePtr> depend_input = {NewValueNode(depend_op), parameter, send};
|
||||
auto depend = main_graph_->NewCNode(depend_input);
|
||||
auto abstract = parameter->abstract();
|
||||
if (care_node) {
|
||||
abstract = care_node->abstract();
|
||||
}
|
||||
depend->set_abstract(abstract);
|
||||
send->set_abstract(abstract);
|
||||
SendAttr send_out = {shape_type_pair.first, shape_type_pair.second, depend};
|
||||
return send_out;
|
||||
}
|
||||
|
@ -630,7 +637,7 @@ std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer:
|
|||
std::vector<AnfNodePtr> send_ops;
|
||||
auto all_nodes = graph->nodes();
|
||||
auto stage_num = g_device_manager->stage_num();
|
||||
if (stage_num > micro_size_) {
|
||||
if (root_->has_flag(TRAINING) && (stage_num > micro_size_)) {
|
||||
MS_LOG(EXCEPTION) << "MicroBatch size: " << micro_size_ << " can't less than stage num: " << stage_num;
|
||||
}
|
||||
for (auto &node : all_nodes) {
|
||||
|
@ -777,21 +784,7 @@ void PipelineTransformer::ElimParameter() {
|
|||
std::vector<AnfNodePtr> parameter_list;
|
||||
for (auto ¶meter : parameters) {
|
||||
if (!manager_->node_users()[parameter].empty()) {
|
||||
if (!root_->has_flag(TRAINING)) {
|
||||
for (auto &node_pair : manager_->node_users()[parameter]) {
|
||||
auto user_node = node_pair.first;
|
||||
if (!IsPrimitiveCNode(user_node, prim::kPrimReceive)) {
|
||||
parameter_list.push_back(parameter);
|
||||
break;
|
||||
}
|
||||
// remove_receive_inputs
|
||||
auto cnode = user_node->cast<CNodePtr>();
|
||||
std::vector<AnfNodePtr> new_inputs = {cnode->input(0)};
|
||||
cnode->set_inputs(new_inputs);
|
||||
}
|
||||
} else {
|
||||
parameter_list.push_back(parameter);
|
||||
}
|
||||
parameter_list.push_back(parameter);
|
||||
}
|
||||
}
|
||||
auto del_num = parameters.size() - parameter_list.size();
|
||||
|
|
|
@ -1621,7 +1621,8 @@ std::pair<AnfNodePtr, int64_t> FindParallelCareNode(const AnfNodePtr &node, int3
|
|||
MS_EXCEPTION_IF_NULL(prim_node_anf);
|
||||
PrimitivePtr node_prim = prim_node_anf->value()->cast<PrimitivePtr>();
|
||||
MS_EXCEPTION_IF_NULL(node_prim);
|
||||
if ((node_prim->name() == DEPEND && node_pair.second != 1) || IsPrimitiveCNode(cnode, prim::kPrimReceive)) {
|
||||
if ((node_prim->name() == DEPEND && node_pair.second != 1) || IsPrimitiveCNode(cnode, prim::kPrimReceive) ||
|
||||
IsPrimitiveCNode(cnode, prim::kPrimSend)) {
|
||||
continue;
|
||||
}
|
||||
if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
|
||||
|
|
|
@ -174,6 +174,31 @@ def test_pipeline_split_shared_parameter_stage1():
|
|||
model = Model(net, optimizer=optimizer)
|
||||
model.train(2, dataset, dataset_sink_mode=False)
|
||||
|
||||
|
||||
def test_pipeline_split_shared_parameter_stage0_predict():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2, full_batch=True)
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
data = Tensor(np.ones([32, 64]), dtype=ms.float32)
|
||||
label = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||
strategy1 = ((4, 1), (1, 1))
|
||||
strategy2 = ((2, 1), (1, 1))
|
||||
net = PipelineSplit2(strategy1, strategy2)
|
||||
model = Model(net)
|
||||
model.predict(data, label)
|
||||
|
||||
|
||||
def test_pipeline_split_shared_parameter_stage1_predict():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2, full_batch=True)
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
data = Tensor(np.ones([32, 64]), dtype=ms.float32)
|
||||
label = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||
strategy1 = ((4, 1), (1, 1))
|
||||
strategy2 = ((2, 1), (1, 1))
|
||||
net = PipelineSplit2(strategy1, strategy2)
|
||||
model = Model(net)
|
||||
model.predict(data, label)
|
||||
|
||||
|
||||
def test_pipeline_split_stage0_opt_shard():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2, enable_parallel_optimizer=True)
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
|
@ -191,6 +216,7 @@ def test_pipeline_split_stage0_opt_shard():
|
|||
assert param.name != "cell.block.1.param"
|
||||
assert param.name != "cell.block.1.param1"
|
||||
|
||||
|
||||
def test_pipeline_split_stage1_opt_shard():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2, enable_parallel_optimizer=True)
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
|
|
Loading…
Reference in New Issue