!65761 pipeline parallel inference

Merge pull request !65761 from chenweifeng/feature-2.3-pipeline-inference
This commit is contained in:
i-robot 2024-03-09 04:47:38 +00:00 committed by Gitee
commit 0aa21ef058
19 changed files with 668 additions and 31 deletions
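At a glance, this MR threads shared-parameter metadata (pipeline_shared, is_send, peer_rank, sr_tag) through the parameter layout and MindIR export, and adds a `mindspore.parallel.sync_pipeline_shared_parameters` API that synchronizes stage-shared weights for pipeline-parallel inference. Below is a minimal usage sketch assembled from the docstring and tests added in this MR; the 4-stage `Net` and the msrun launch come from the new test file, so treat it as an illustrative outline rather than part of the diff:

import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.communication import init
from mindspore.parallel import sync_pipeline_shared_parameters

init()
ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")
ms.set_auto_parallel_context(parallel_mode="semi_auto_parallel", full_batch=True,
                             pipeline_stages=4, pipeline_result_broadcast=True)

net = Net()                            # pipeline-staged inference network, defined as in the new test below
net.set_train(False)
x = Tensor(np.ones((8, 8)), ms.float32)
net.compile(x)                         # compile first ...
sync_pipeline_shared_parameters(net)   # ... then synchronize stage-shared parameters (e.g. an embedding table)
out = net(x)                           # with pipeline_result_broadcast=True every stage receives the final result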

View File

@ -19,6 +19,7 @@
#include <memory>
#include <string>
#include <vector>
#include <tuple>
#include <unordered_map>
#include "ir/func_graph.h"
@ -28,6 +29,7 @@
#include "frontend/parallel/tensor_layout/tensor_layout.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/parameter_manager.h"
#include "frontend/parallel/tensor_layout/shared_parameter.h"
namespace mindspore {
namespace parallel {
@ -226,6 +228,23 @@ py::dict GetParallelCNodeInfoFromSubGraph(const FuncGraphPtr &sub_graph, const F
}
return cnode_info_dict;
}
std::tuple<bool, bool, int64_t, int64_t> GetSharedParameterInfo(const AnfNodePtr &param) {
MS_EXCEPTION_IF_NULL(param);
bool is_pipeline_shared = false;
bool is_send = false;
int64_t peer_rank = 0;
int64_t sr_tag = 0;
auto shared_params = param->user_data<parallel::SharedParameter>();
if (shared_params) {
is_pipeline_shared = shared_params->pipeline_shared();
is_send = shared_params->is_send();
peer_rank = shared_params->peer_rank();
sr_tag = shared_params->sr_tag();
}
return std::tuple(is_pipeline_shared, is_send, peer_rank, sr_tag);
}
} // namespace
py::dict GetParameterLayoutFromGraph(const FuncGraphPtr &graph) {
@ -261,8 +280,10 @@ py::dict GetParameterLayoutFromGraph(const FuncGraphPtr &graph) {
int64_t field_size = tensor_layout->get_field_size();
bool uniform_split = tensor_layout->uniform_split();
const std::string &opt_shard_group = tensor_layout->opt_shard_group();
py::tuple layout =
py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group);
auto [is_pipeline_shared, is_send, peer_rank, sr_tag] = GetSharedParameterInfo(para);
py::tuple layout = py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split,
opt_shard_group, is_pipeline_shared, is_send, peer_rank, sr_tag);
for (auto &name : names) {
dict[py::str(name)] = layout;
}
@ -285,8 +306,12 @@ py::dict GetParameterLayoutFromResource(const pipeline::ResourcePtr &resource) {
int64_t field_size = layout->get_field_size();
bool uniform_split = layout->get_uniform_split();
const std::string &opt_shard_group = layout->get_opt_shard_group();
py::tuple layout_tuple =
py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group);
bool is_pipeline_shared = layout->pipeline_shared();
bool is_send = layout->is_send();
int64_t peer_rank = layout->peer_rank();
int64_t sr_tag = layout->sr_tag();
py::tuple layout_tuple = py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split,
opt_shard_group, is_pipeline_shared, is_send, peer_rank, sr_tag);
dict[py::str(name)] = layout_tuple;
}
return dict;
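For reference, the per-parameter layout tuple returned to Python grows from six to ten fields with this change. A small consumer-side sketch of the assumed field order (it also shows why the updated UTs below compare only `layout[0:6]`):

def unpack_layout(layout):
    """Split a 10-field layout tuple into the original layout fields and the new shared-parameter fields."""
    # Fields 0-5: existing layout information.
    device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group = layout[0:6]
    # Fields 6-9: pipeline shared-parameter information added by this MR.
    is_pipeline_shared, is_send, peer_rank, sr_tag = layout[6:10]
    return (device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group), \
           (is_pipeline_shared, is_send, peer_rank, sr_tag)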

View File

@ -73,7 +73,7 @@ void SpreadFineGrainedInterleavedIndexForForwardCommNodes(const CNodePtr &cnode,
if (pre_cnode == nullptr) {
continue;
}
// BFS end search contition
// BFS end search condition
if (IsPrimitiveCNode(pre_cnode, prim::kPrimStridedSlice) &&
GetCNodePrimitive(pre_cnode)->HasAttr(kAttrFineGrainedInterleavedBlockIndex)) {
pre_cnode->AddAttr("fine_grained_interleaved_border", MakeValue<size_t>(0));
@ -219,6 +219,10 @@ void LabelFineGrainedInterleavedIndex(const FuncGraphPtr &graph) {
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
if (!IsTraining(manager)) {
return;
}
FuncGraphPtr forward_graph = graph;
FuncGraphPtr backward_graph = graph;
auto context = MsContext::GetInstance();

View File

@ -40,6 +40,7 @@
#include "frontend/parallel/graph_util/pipeline_split_utils.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "frontend/parallel/graph_util/graph_splitter.h"
#include "frontend/parallel/tensor_layout/shared_parameter.h"
#include "ir/anf.h"
#include "ir/graph_utils.h"
#include "include/common/utils/comm_manager.h"
@ -71,6 +72,45 @@ static AbstractBasePtr GetRealAbstract(const AnfNodePtr &node) {
return node->abstract();
}
void PipelineTransformer::UpdateParameterSharedInfo(const AnfNodePtr &node, const AnfNodePtr &communicate_op,
bool is_send) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(communicate_op);
if (!node->isa<Parameter>()) {
return;
}
auto root_param = node;
if (node->func_graph() != root_) {
root_param = GetArgumentsByParameter(node);
MS_EXCEPTION_IF_NULL(root_param);
}
// Get communication info from the Send/Receive cnode.
auto prim = GetCNodePrimitive(communicate_op);
MS_EXCEPTION_IF_NULL(prim);
auto sr_tag_attr = prim->GetAttr(SR_TAG);
MS_EXCEPTION_IF_NULL(sr_tag_attr);
auto sr_tag = GetValue<int64_t>(sr_tag_attr);
auto peer_rank_attr = is_send ? prim->GetAttr(DEST_RANK) : prim->GetAttr(SRC_RANK);
MS_EXCEPTION_IF_NULL(peer_rank_attr);
auto peer_rank = GetValue<int64_t>(peer_rank_attr);
auto group_attr = prim->GetAttr(GROUP);
MS_EXCEPTION_IF_NULL(group_attr);
auto group = GetValue<std::string>(group_attr);
// Use global rank since local group may not exist after loading checkpoint.
auto rank_list = g_device_manager->FindRankListByHashName(group);
peer_rank = rank_list.at(peer_rank);
// Record the shared-parameter info on the root parameter.
auto param = root_param->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
auto shared_parameters = std::make_shared<SharedParameter>(true, is_send, peer_rank, sr_tag);
param->set_user_data<SharedParameter>(shared_parameters);
}
TensorInfo PipelineTransformer::GetTensorInfo(const std::pair<OperatorInfoPtr, int> &op_info_pair, bool is_param) {
if (is_param) {
auto inputs_tensor_info = op_info_pair.first->inputs_tensor_info();
@ -310,7 +350,7 @@ size_t PipelineTransformer::GetBatchAxisForInput(const AnfNodeIndexSet &input_no
}
}
}
if (batch_axis_count != kSizeOne) {
if (is_train_ && batch_axis_count != kSizeOne) {
MS_LOG(EXCEPTION)
<< "For pipeline parallelism, the input must be split into micro_size parts along exactly one dimension, "
<< "but " << batch_axis_count << " dimensions were found to be split.";
@ -1112,7 +1152,9 @@ AnfNodePtr PipelineTransformer::HandleParameterGraph(const AnfNodePtr &node, con
}
(void)parameter_color_map_[root_param].insert(user_stage);
auto graph = enable_share_cell_ ? shared_cell_ : main_graph_;
return InsertReceive(graph, argument, use_node, SizeToInt(pos), user_stage, stage, micro, parameter);
auto recv_node = InsertReceive(graph, argument, use_node, SizeToInt(pos), user_stage, stage, micro, parameter);
UpdateParameterSharedInfo(root_param, recv_node, false);
return recv_node;
}
// insert send
if (Reuse(argument, user_stage, ops, DEST_RANK)) {
@ -1121,6 +1163,7 @@ AnfNodePtr PipelineTransformer::HandleParameterGraph(const AnfNodePtr &node, con
auto send_out = InsertSend(argument, user_stage, stage_, micro);
send_out.depend->set_user_data<Type>(DTYPE, send_out.type);
send_out.depend->set_user_data<ValueList>(SHAPE, send_out.shape);
UpdateParameterSharedInfo(argument, send_out.depend, true);
return send_out.depend;
}
@ -1374,7 +1417,7 @@ std::vector<AnfNodePtr> PipelineTransformer::FetchSend(const AnfNodePtr &node, b
for (auto &user : shared_cell_users_) {
auto cuser = user->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cuser);
auto value = cuser->GetPrimalAttr(MICRO);
auto value = shared_cell_users_.size() > 1 ? cuser->GetPrimalAttr(MICRO) : MakeValue(int64_t(0));
MS_EXCEPTION_IF_NULL(value);
send_input = single_pipeline_end ? user : CreateTupleGetItemNode(main_graph_, user, end_index);
(void)(depends.emplace_back(GenNewSendFromOld(node, send_input, value)));
@ -1388,6 +1431,11 @@ void PipelineTransformer::HandleGraphOutputs(const std::vector<AnfNodePtr> &node
SeparateParamBorder(nodes, true, &pipeline_params, &pipeline_ends);
std::vector<AnfNodePtr> sends;
SetNodeAbstract(pipeline_ends);
// Create the root graph output before modifying the subgraph (shared cell).
// This ordering is crucial when the subgraph's output is used directly as the root graph output.
auto zero_outputs = GetZeroOutputs(main_graph_);
size_t ends_size = pipeline_ends.size();
bool single_pipeline_end = ends_size == 1;
if (single_pipeline_end) {
@ -1408,7 +1456,9 @@ void PipelineTransformer::HandleGraphOutputs(const std::vector<AnfNodePtr> &node
}
for (auto &node : pipeline_params) {
auto params = FetchSend(node, true, false, 0);
(void)std::copy(params.begin(), params.end(), std::back_inserter(sends));
if (is_train_) {
(void)std::copy(params.begin(), params.end(), std::back_inserter(sends));
}
}
for (size_t i = 0; i < ends_size; i++) {
auto node = pipeline_ends[i];
@ -1416,7 +1466,6 @@ void PipelineTransformer::HandleGraphOutputs(const std::vector<AnfNodePtr> &node
(void)std::copy(ends.begin(), ends.end(), std::back_inserter(sends));
}
auto make_tuple = CreateMakeTupleNode(main_graph_, sends);
auto zero_outputs = GetZeroOutputs(main_graph_);
std::vector<AnfNodePtr> out = {NewValueNode(prim::kPrimDepend), zero_outputs, make_tuple};
auto out_node = main_graph_->NewCNode(out);
out_node->set_abstract(zero_outputs->abstract());
@ -1486,7 +1535,11 @@ std::vector<AnfNodePtr> PipelineTransformer::FetchRecv(const AnfNodePtr &node, b
recv_input = user->input(input_pos);
recv = GenNewRecvFromOld(node, recv_input, value);
for (auto &share_user : shared_cell_users_) {
manager_->SetEdge(share_user, input_pos, recv);
if (is_train_) {
manager_->SetEdge(share_user, input_pos, recv);
} else {
manager_->SetEdge(share_user, input_pos, recv_input);
}
}
node->set_user_data<bool>(ORIGIN_INPUT_IS_PARAM, std::make_shared<bool>(true));
} else {
@ -1501,7 +1554,7 @@ std::vector<AnfNodePtr> PipelineTransformer::FetchRecv(const AnfNodePtr &node, b
for (auto &user : shared_cell_users_) {
auto cuser = user->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cuser);
auto value = cuser->GetPrimalAttr(MICRO);
auto value = shared_cell_users_.size() > 1 ? cuser->GetPrimalAttr(MICRO) : MakeValue(int64_t(0));
MS_EXCEPTION_IF_NULL(value);
if (enable_share_cell_ || !is_train_) {
auto recv_tensor = TensorConstructUtils::CreateZerosTensor(kFloat16, {1});
@ -1549,6 +1602,7 @@ void PipelineTransformer::ResetSharedCellParamAndArgu(
MS_LOG(DEBUG) << "The shared cell origin params size is " << params.size() << ", new params size is "
<< new_params.size();
manager_->SetParameters(shared_cell_, new_params);
shared_cell_->set_fv_param_count(new_params.size());
// set call inputs
size_t user_index = 0;
for (auto &user : shared_cell_users_) {

View File

@ -118,6 +118,7 @@ class PipelineTransformer {
bool GetStageByArgument(const CNodePtr &node, size_t index, const std::vector<AnfNodePtr> &parameters,
const NodeUsersMap &node_users_map, std::set<int64_t> *const parameter_stage);
size_t GetBatchAxisForInput(const AnfNodeIndexSet &input_node_users) const;
void UpdateParameterSharedInfo(const AnfNodePtr &node, const AnfNodePtr &communicate_op, bool is_send);
FuncGraphManagerPtr manager_;
int64_t stage_;
FuncGraphPtr root_;

View File

@ -3025,6 +3025,50 @@ static void MoveMicroMirrorOutCallFunc(const FuncGraphPtr &root) {
}
}
static void BroadcastMultiOutputs(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager, const Group &group) {
auto output = root->get_return()->input(1)->cast<CNodePtr>();
auto output_abstract = output->abstract();
MS_EXCEPTION_IF_NULL(output_abstract);
auto abstract_tuple = output_abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(abstract_tuple);
auto abstract_list = abstract_tuple->elements();
AnfNodePtrList make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)};
for (size_t i = 0; i < abstract_list.size(); i++) {
auto abstract = abstract_list[i];
MS_EXCEPTION_IF_NULL(abstract);
// TupleGetItem
auto idx = NewValueNode(SizeToLong(i));
CNodePtr tuple_getitem = root->NewCNode({NewValueNode(prim::kPrimTupleGetItem), output, idx});
MS_EXCEPTION_IF_NULL(tuple_getitem);
tuple_getitem->set_abstract(abstract);
// Depend: prevent reordering and CSE
if (i > 0) {
tuple_getitem = root->NewCNode({NewValueNode(prim::kPrimDepend), tuple_getitem, make_tuple_input[i]});
MS_EXCEPTION_IF_NULL(tuple_getitem);
tuple_getitem->set_abstract(abstract);
}
// Allreduce
CNodePtr allreduce = root->NewCNode({NewValueNode(prim::kPrimAllReduce), tuple_getitem});
MS_EXCEPTION_IF_NULL(allreduce);
allreduce->set_abstract(abstract);
common::AnfAlgo::SetNodeAttr(OP, MakeValue(REDUCE_OP_SUM), allreduce);
common::AnfAlgo::SetNodeAttr(GROUP, MakeValue(group.name()), allreduce);
// Disable GE allreduce fusion.
common::AnfAlgo::SetNodeAttr(FUSION, MakeValue(static_cast<int64_t>(0)), allreduce);
make_tuple_input.push_back(allreduce);
}
CNodePtr make_tuple_node = root->NewCNode(make_tuple_input);
MS_EXCEPTION_IF_NULL(make_tuple_node);
make_tuple_node->set_abstract(abstract_tuple);
(void)manager->Replace(output, make_tuple_node);
}
static void BroadcastLastResult(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
auto pipeline_result_broadcast = parallel::ParallelContext::GetInstance()->pipeline_result_broadcast();
@ -3032,12 +3076,6 @@ static void BroadcastLastResult(const FuncGraphPtr &root, const FuncGraphManager
return;
}
auto return_node = root->get_return();
const auto &abstract = return_node->abstract();
if (abstract->isa<abstract::AbstractTuple>()) {
return;
}
std::vector<int64_t> rank_list = g_device_manager->GetDeviceListBetweenStage();
Group group;
if (g_device_manager->CreateGroup(rank_list, &group) != SUCCESS) {
@ -3045,6 +3083,12 @@ static void BroadcastLastResult(const FuncGraphPtr &root, const FuncGraphManager
<< rank_list;
}
auto return_node = root->get_return();
const auto &abstract = return_node->abstract();
if (abstract->isa<abstract::AbstractTuple>()) {
return BroadcastMultiOutputs(root, manager, group);
}
InsertAllReduceToNodeInput(return_node, group.name(), PARALLEL_RESULT_BROADCAST);
return_node->input(1)->set_abstract(abstract);
}

View File

@ -0,0 +1,54 @@
/**
* Copyright 2024 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_SHARED_PARAMETER_TENSOR_LAYOUT_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_SHARED_PARAMETER_TENSOR_LAYOUT_H_
#include <string>
namespace mindspore {
namespace parallel {
class SharedParameter {
public:
SharedParameter(bool pipeline_shared, bool is_send, int64_t peer_rank, int64_t sr_tag)
: pipeline_shared_(pipeline_shared), is_send_(is_send), peer_rank_(peer_rank), sr_tag_(sr_tag) {}
~SharedParameter() = default;
void set_pipeline_shared(bool pipeline_shared) { pipeline_shared_ = pipeline_shared; }
bool pipeline_shared() const { return pipeline_shared_; }
void set_is_send(bool is_send) { is_send_ = is_send; }
bool is_send() const { return is_send_; }
void set_peer_rank(int64_t peer_rank) { peer_rank_ = peer_rank; }
int64_t peer_rank() const { return peer_rank_; }
void set_sr_tag(int64_t sr_tag) { sr_tag_ = sr_tag; }
int64_t sr_tag() const { return sr_tag_; }
// Key for user data.
constexpr static char key[] = "SharedParameter";
private:
bool pipeline_shared_ = false;
bool is_send_ = false;
int64_t peer_rank_{0};
int64_t sr_tag_{0};
};
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_SHARED_PARAMETER_TENSOR_LAYOUT_H_

View File

@ -28,6 +28,7 @@
#include "utils/system/sha256.h"
#include "include/common/utils/utils.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/tensor_layout/shared_parameter.h"
#include "mindspore/core/utils/file_utils.h"
#if defined(__linux__) && defined(WITH_BACKEND)
@ -76,6 +77,13 @@ void BuildLayout(const FuncGraphPtr &func_graph, mind_ir::ModelProto *model) {
layoutProto->set_field_size(field_size);
layoutProto->set_uniform_split(uniform_split);
layoutProto->set_opt_shard_group(opt_shard_group);
auto shared_param = para->user_data<parallel::SharedParameter>();
if (shared_param) {
layoutProto->set_pipeline_shared(shared_param->pipeline_shared());
layoutProto->set_is_send(shared_param->is_send());
layoutProto->set_peer_rank(shared_param->peer_rank());
layoutProto->set_sr_tag(shared_param->sr_tag());
}
}
}
}

View File

@ -2299,6 +2299,18 @@ const LayoutMap MSANFModelParser::ParseLayout(const mind_ir::ModelProto &model_p
cur_layout->set_uniform_split(uniform_spilt);
cur_layout->set_opt_shard_group(opt_shard_group);
// Check optional field for backward compatibility.
if (layout_proto.has_pipeline_shared()) {
bool pipeline_shared = layout_proto.pipeline_shared();
bool is_send = layout_proto.is_send();
int64_t peer_rank = layout_proto.peer_rank();
int64_t sr_tag = layout_proto.sr_tag();
cur_layout->set_pipeline_shared(pipeline_shared);
cur_layout->set_is_send(is_send);
cur_layout->set_peer_rank(peer_rank);
cur_layout->set_sr_tag(sr_tag);
}
ret[name] = cur_layout;
}
return ret;

View File

@ -42,6 +42,14 @@ class Layout {
void set_uniform_split(bool uniform_split) { uniform_split_ = uniform_split; }
const std::string &get_opt_shard_group() const { return opt_shard_group_; }
void set_opt_shard_group(const std::string &opt_shard_group) { opt_shard_group_ = opt_shard_group; }
void set_pipeline_shared(bool pipeline_shared) { pipeline_shared_ = pipeline_shared; }
bool pipeline_shared() const { return pipeline_shared_; }
void set_is_send(bool is_send) { is_send_ = is_send; }
bool is_send() const { return is_send_; }
void set_peer_rank(int64_t peer_rank) { peer_rank_ = peer_rank; }
int64_t peer_rank() const { return peer_rank_; }
void set_sr_tag(int64_t sr_tag) { sr_tag_ = sr_tag; }
int64_t sr_tag() const { return sr_tag_; }
private:
std::vector<int64_t> device_arrangement_{};
@ -50,6 +58,11 @@ class Layout {
int64_t field_size_ = 0;
bool uniform_split_ = false;
std::string opt_shard_group_ = "";
// pipeline stage shared param info
bool pipeline_shared_ = false;
bool is_send_ = false;
int64_t peer_rank_{0};
int64_t sr_tag_{0};
};
using LayoutPtr = std::shared_ptr<Layout>;
using LayoutMap = std::map<string, LayoutPtr>;

View File

@ -233,6 +233,10 @@ message LayoutProto {
optional int64 field_size = 5;
optional bool uniform_split = 6;
optional string opt_shard_group = 7;
optional bool pipeline_shared = 8;
optional bool is_send = 9;
optional int64 peer_rank = 10;
optional int64 sr_tag = 11;
}
message PrimitiveProto {

View File

@ -18,8 +18,9 @@ from __future__ import absolute_import
from mindspore.parallel.algo_parameter_config import get_algo_parameters, reset_algo_parameters, \
set_algo_parameters
from mindspore.parallel.checkpoint_transform import rank_list_for_transform, transform_checkpoint_by_rank, \
transform_checkpoints, merge_pipeline_strategys
transform_checkpoints, merge_pipeline_strategys, sync_pipeline_shared_parameters
from mindspore.parallel.shard import shard
__all__ = ["set_algo_parameters", "reset_algo_parameters", "get_algo_parameters", "rank_list_for_transform",
"transform_checkpoint_by_rank", "transform_checkpoints", "merge_pipeline_strategys", "shard"]
"transform_checkpoint_by_rank", "transform_checkpoints", "merge_pipeline_strategys", "shard",
"sync_pipeline_shared_parameters"]

View File

@ -22,6 +22,7 @@ from collections import defaultdict
import numpy as np
import mindspore as ms
from mindspore.common import dtype as mstype
from mindspore.parallel._utils import _is_in_auto_parallel_mode
from mindspore.parallel._parallel_serialization import _rank_list_for_transform_parallel_checkpoint, \
_transform_parallel_checkpoint, _get_device_num_from_strategy, _make_dir, \
_extract_layout_map, _extract_src_dst_layout_map, _parameter_not_in_local_stage, _extract_pipeline_stage_num, \
@ -29,7 +30,7 @@ from mindspore.parallel._parallel_serialization import _rank_list_for_transform_
__all__ = ["merge_pipeline_strategys", "rank_list_for_transform", "transform_checkpoint_by_rank",
"transform_checkpoints"]
"transform_checkpoints", "sync_pipeline_shared_parameters"]
def merge_pipeline_strategys(src_strategy_dirs, dst_strategy_file):
@ -336,3 +337,121 @@ def transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix,
ms.save_checkpoint(transform_param_list, save_checkpoint_file_name)
del param_total_dict_copy
del param_total_dict
def _sync_params(name, param, layout):
    """Synchronize a single parameter."""
    if len(layout) < 10:
        ms.log.warning("The layout dict does not contain the pipeline_shared_param info %s", name)
        return
    pipeline_shared = layout[6]
    if not pipeline_shared:
        return
    is_send = layout[7]
    peer_rank = layout[8]
    sr_tag = layout[9]

    class SharedParameterSyncCell(ms.nn.Cell):
        """Synchronization cell for a single shared parameter."""
        def __init__(self, param, is_send, peer_rank, sr_tag):
            super().__init__()
            self.param = param
            self.is_send = is_send
            self.ret = ms.Tensor([0])
            from mindspore.ops.operations._inner_ops import Send, Receive
            if self.is_send:
                self.send = Send(sr_tag=sr_tag, dest_rank=peer_rank)
            else:
                self.receive = Receive(sr_tag=sr_tag, src_rank=peer_rank, shape=param.shape, dtype=param.dtype)

        def construct(self):
            if self.is_send:
                out = self.send(self.param)
                return ms.ops.functional.depend(self.ret, out)
            self.param = self.receive(self.ret)
            return ms.ops.functional.depend(self.ret, self.param)

    sync_net = SharedParameterSyncCell(param, is_send, peer_rank, sr_tag)
    sync_net()


def sync_pipeline_shared_parameters(net):
    """Synchronize pipeline parallel stage shared parameters.

    Parameters may be shared between different stages in pipeline parallel inference. For example, the
    `embedding table` is shared by the `WordEmbedding` layer and the `LMHead` layer, which are usually split
    into different stages. It is necessary to perform synchronization after the `embedding table` changes.

    Note:
        The network should be compiled before synchronizing the pipeline parallel stage shared parameters.

    Args:
        net (nn.Cell): The inference network.

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> from mindspore import nn, ops, lazy_inline, Parameter, Tensor
        >>> from mindspore.communication import init
        >>> ms.set_context(mode=ms.GRAPH_MODE)
        >>> init()
        >>> ms.set_auto_parallel_context(parallel_mode="semi_auto_parallel", full_batch=True, pipeline_stages=2)
        >>> class VocabEmbedding(nn.Cell):
        ...     def __init__(self, vocab_size, embedding_size):
        ...         super().__init__()
        ...         self.embedding_table = Parameter(Tensor(np.ones([vocab_size, embedding_size])), name='embedding')
        ...         self.gather = ops.Gather()
        ...
        ...     def construct(self, x):
        ...         output = self.gather(self.embedding_table, x, 0)
        ...         return output, self.embedding_table.value()
        ...
        >>> class LMHead(nn.Cell):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.matmul = ops.MatMul(transpose_b=True)
        ...
        ...     def construct(self, state, embed):
        ...         state = state.reshape(-1, state.shape[-1])
        ...         return self.matmul(state, embed)
        ...
        >>> class Network(nn.Cell):
        ...     @lazy_inline
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.word_embedding = VocabEmbedding(vocab_size=4, embedding_size=4)
        ...         self.head = LMHead()
        ...
        ...     def construct(self, x):
        ...         x, embed = self.word_embedding(x)
        ...         x = self.head(x, embed)
        ...         return x
        >>>
        >>> net = Network()
        >>> net.word_embedding.pipeline_stage = 0
        >>> net.head.pipeline_stage = 1
        >>> x = Tensor(np.ones((8, 4)), ms.int32)
        >>> net.compile(x)
        >>> ms.parallel.sync_pipeline_shared_parameters(net)
        >>> print(net.word_embedding.embedding_table.asnumpy())
        [[1. 1. 1. 1.]
         [1. 1. 1. 1.]
         [1. 1. 1. 1.]
         [1. 1. 1. 1.]]
    """
    layout_dict = net.parameter_layout_dict
    if _is_in_auto_parallel_mode() and not layout_dict:
        from mindspore.common.api import _get_parameter_layout
        layout_dict = _get_parameter_layout()

    # switch to standalone mode
    parallel_mode = ms.context.get_auto_parallel_context("parallel_mode")
    full_batch = ms.context.get_auto_parallel_context("full_batch")
    ms.context.set_auto_parallel_context(parallel_mode="stand_alone", full_batch=False)

    # synchronize shared parameter
    for name, param in net.parameters_and_names():
        if name in layout_dict:
            _sync_params(name, param, layout_dict[name])

    # restore parallel context
    ms.context.set_auto_parallel_context(parallel_mode=parallel_mode, full_batch=full_batch)

View File

@ -63,6 +63,7 @@ from mindspore.parallel._parallel_serialization import _convert_to_list, _conver
_restore_group_info_list
from mindspore.parallel._ps_context import _set_checkpoint_load_status, _store_warm_up_ptr_by_tensor, \
_store_warm_up_ptr_by_tensor_list, _cache_enable
from mindspore.parallel.checkpoint_transform import sync_pipeline_shared_parameters
from mindspore.train._utils import read_proto
from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir, \
split_mindir, split_dynamic_mindir
@ -559,6 +560,7 @@ def _convert_cell_param_and_names_to_dict(save_obj, choice_func):
def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func):
    """Convert nn.Cell to param_list."""
    sync_pipeline_shared_parameters(save_obj)
    param_list = []
    parameter_layout_dict = save_obj.parameter_layout_dict
    if _is_in_auto_parallel_mode() and not parameter_layout_dict:

View File

@ -0,0 +1,263 @@
import numpy as np
import mindspore
import mindspore.communication.management as D
from mindspore import lazy_inline, context, nn, Tensor, Parameter
from mindspore.ops import operations as P
from mindspore.parallel.checkpoint_transform import sync_pipeline_shared_parameters

context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
pipeline_stages = 4


class FC(nn.Cell):
    def __init__(self, shape):
        super().__init__()
        self.w = Parameter(Tensor(np.ones(shape), mindspore.float32), name="weight")
        self.matmul = P.MatMul()

    def construct(self, x):
        return self.matmul(x, self.w)


class WordEmbedding(nn.Cell):
    def __init__(self, shape):
        super().__init__()
        self.w = Parameter(Tensor(np.ones(shape), mindspore.float32), name="weight")
        self.matmul = P.MatMul()

    def construct(self, x):
        return self.matmul(x, self.w), self.w


class LMHead(nn.Cell):
    def __init__(self, shape):
        super().__init__()
        self.w = Parameter(Tensor(np.ones(shape), mindspore.float32), name="weight")
        self.matmul1 = P.MatMul()
        self.matmul2 = P.MatMul()

    def construct(self, x, w):
        x = self.matmul1(x, self.w)
        return self.matmul2(x, w), x + x


class Net(nn.Cell):
    @lazy_inline
    def __init__(self):
        super().__init__()
        shape = (8, 8)
        self.word_embedding = WordEmbedding(shape)
        self.decoder1 = FC(shape)
        self.decoder2 = FC(shape)
        self.lm_head = LMHead(shape)
        self.word_embedding.matmul.shard(((1, 1), (1, 1)))
        self.decoder1.matmul.shard(((1, 1), (1, 1)))
        self.decoder2.matmul.shard(((1, 1), (1, 1)))
        self.lm_head.matmul1.shard(((1, 1), (1, 1)))
        self.lm_head.matmul2.shard(((1, 1), (1, 1)))
        self.word_embedding.pipeline_stage = 0
        self.decoder1.pipeline_stage = 1
        self.decoder2.pipeline_stage = 2
        self.lm_head.pipeline_stage = 3

    def construct(self, x):
        x, w = self.word_embedding(x)
        x = self.decoder1(x)
        x = self.decoder2(x)
        x, y = self.lm_head(x, w)
        return x, y


class PipelineCellInference(nn.Cell):
    def __init__(self, network, micro_batch_num):
        super().__init__()
        self.network = network
        self.micro_batch_num = micro_batch_num
        self.concat = P.Concat()

    def construct(self, x):
        ret_x = ()
        ret_y = ()
        for i in range(self.micro_batch_num):
            micro_batch_size = x.shape[0] // self.micro_batch_num
            start = micro_batch_size * i
            end = micro_batch_size * (i + 1)
            micro_input = x[start:end]
            ret1, ret2 = self.network(micro_input)
            ret_x = ret_x + (ret1,)
            ret_y = ret_y + (ret2,)
        ret_x = self.concat(ret_x)
        ret_y = self.concat(ret_y)
        return ret_x, ret_y


def get_stage_id():
    rank_size = D.get_group_size()
    rank_id = D.get_rank()
    rank_per_stage = rank_size // pipeline_stages
    return rank_id // rank_per_stage


def test_pipeline_inference_basic():
    """
    Feature: Pipeline parallel inference
    Description: Micro batch split
    Expectation: success
    """
    D.init()
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", full_batch=True,
                                      pipeline_stages=pipeline_stages)
    net = PipelineCellInference(Net(), micro_batch_num=2)
    net.set_train(False)

    shape = (8, 8)
    x = Tensor(np.ones(shape), mindspore.float32)
    ret = net(x)
    expect = [[np.zeros(shape, np.float32), np.zeros(shape, np.float32)],
              [np.ones(shape, np.float32) * pow(8, 5), np.ones(shape, np.float32) * pow(8, 4) * 2]]
    is_last_stage = get_stage_id() == pipeline_stages - 1
    assert np.allclose(ret[0].asnumpy(), expect[is_last_stage][0])
    assert np.allclose(ret[1].asnumpy(), expect[is_last_stage][1])


def test_pipeline_inference_broadcast():
    """
    Feature: Pipeline parallel inference
    Description: Broadcast last stage result, multi-output.
    Expectation: success
    """
    D.init()
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", full_batch=True,
                                      pipeline_stages=pipeline_stages, pipeline_result_broadcast=True)
    net = PipelineCellInference(Net(), micro_batch_num=4)
    net.set_train(False)

    shape = (8, 8)
    x = Tensor(np.ones(shape), mindspore.float32)
    ret = net(x)
    print(ret)
    expect = [np.ones(shape, np.float32) * pow(8, 5), np.ones(shape, np.float32) * pow(8, 4) * 2]
    assert np.allclose(ret[0].asnumpy(), expect[0])
    assert np.allclose(ret[1].asnumpy(), expect[1])


def test_pipeline_inference_single_micro_batch():
    """
    Feature: Pipeline parallel inference
    Description: Broadcast last stage result, without micro batch split
    Expectation: success
    """
    D.init()
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", full_batch=True,
                                      pipeline_stages=pipeline_stages, pipeline_result_broadcast=True)
    net = PipelineCellInference(Net(), micro_batch_num=1)
    net.set_train(False)

    shape = (8, 8)
    x = Tensor(np.ones(shape), mindspore.float32)
    ret = net(x)
    print(ret)
    expect = [np.ones(shape, np.float32) * pow(8, 5), np.ones(shape, np.float32) * pow(8, 4) * 2]
    assert np.allclose(ret[0].asnumpy(), expect[0])
    assert np.allclose(ret[1].asnumpy(), expect[1])


def test_pipeline_inference_without_wrapper():
    """
    Feature: Pipeline parallel inference
    Description: Broadcast last stage result, without wrapper
    Expectation: success
    """
    D.init()
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", full_batch=True,
                                      pipeline_stages=pipeline_stages, pipeline_result_broadcast=True)
    net = Net()
    net.set_train(False)

    shape = (8, 8)
    x = Tensor(np.ones(shape), mindspore.float32)
    ret = net(x)
    print(ret)
    expect = [np.ones(shape, np.float32) * pow(8, 5), np.ones(shape, np.float32) * pow(8, 4) * 2]
    assert np.allclose(ret[0].asnumpy(), expect[0])
    assert np.allclose(ret[1].asnumpy(), expect[1])


class PipelineCellInferenceSingleOutput(nn.Cell):
    def __init__(self, network, micro_batch_num):
        super().__init__()
        self.network = network
        self.micro_batch_num = micro_batch_num
        self.concat = P.Concat()

    def construct(self, x):
        ret = ()
        for i in range(self.micro_batch_num):
            micro_batch_size = x.shape[0] // self.micro_batch_num
            start = micro_batch_size * i
            end = micro_batch_size * (i + 1)
            micro_input = x[start:end]
            ret1, _ = self.network(micro_input)
            ret = ret + (ret1,)
        ret = self.concat(ret)
        return ret


def test_pipeline_inference_single_output():
    """
    Feature: Pipeline parallel inference
    Description: Micro batch split
    Expectation: success
    """
    D.init()
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", full_batch=True,
                                      pipeline_stages=pipeline_stages, pipeline_result_broadcast=True)
    net = PipelineCellInferenceSingleOutput(Net(), micro_batch_num=2)
    net.set_train(False)

    shape = (8, 8)
    x = Tensor(np.ones(shape), mindspore.float32)
    ret = net(x)
    print(ret)
    expect = [np.ones(shape, np.float32) * pow(8, 5), np.ones(shape, np.float32) * pow(8, 5)]
    is_last_stage = get_stage_id() == pipeline_stages - 1
    assert np.allclose(ret.asnumpy(), expect[is_last_stage])


def test_pipeline_inference_shared_params():
    """
    Feature: Pipeline parallel inference
    Description: Shared parameters synchronize
    Expectation: success
    """
    D.init()
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", full_batch=True,
                                      pipeline_stages=pipeline_stages, pipeline_result_broadcast=True)
    net = Net()
    net.set_train(False)
    if get_stage_id() == pipeline_stages - 1:
        shape, dtype = net.word_embedding.w.shape, net.word_embedding.w.dtype
        net.word_embedding.w.set_data(Tensor(np.zeros(shape), dtype))

    shape = (8, 8)
    x = Tensor(np.ones(shape), mindspore.float32)
    # compile and synchronize
    net.compile(x)
    sync_pipeline_shared_parameters(net)
    ret = net(x)
    print(ret)
    expect = [np.ones(shape, np.float32) * pow(8, 5), np.ones(shape, np.float32) * pow(8, 4) * 2]
    assert np.allclose(ret[0].asnumpy(), expect[0])
    assert np.allclose(ret[1].asnumpy(), expect[1])

View File

@ -0,0 +1,32 @@
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import pytest
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single
def test_msrun_pipeline_parallel_inference():
    '''
    Feature: Pipeline parallel inference.
    Description: Test pipeline parallel inference launched with msrun.
    Expectation: Run success.
    '''
    ret = os.system("msrun --worker_num=4 --local_worker_num=4 --master_addr=127.0.0.1 --master_port=10969 "
                    "--join=True --log_dir=./pipeline_inference_logs pytest -s -v "
                    "pipeline_inference.py")
    assert ret == 0

View File

@ -57,7 +57,8 @@ def test_get_parameter_layout():
weight_layout = ([2, 4], [0, -1], [16, 32], 0, True, '') # device_arrangement = [2, 4], tensor_map = [0, -1]
expect_dict = {'x': x_layout, 'w1': weight_layout}
# To be resolved: the static local variable count_p used in step_parallel.cc needs to be reset between each ut
assert net.parameter_layout_dict == expect_dict
assert net.parameter_layout_dict["x"][0:6] == expect_dict["x"]
assert net.parameter_layout_dict["w1"][0:6] == expect_dict["w1"]
if __name__ == '__main__':

View File

@ -111,9 +111,8 @@ def test_pipeline_inference_first_stage():
x = Tensor(np.ones((batch_size, hidden_size)), mindspore.float32)
phase = compile_infer_net(net, x)
validator = ParallelValidator(net, phase)
assert validator.check_node_inputs_has('Send-0', ['network.fc0.weight'], graph_id=0)
assert validator.check_node_inputs_has('Send-1', ['pipeline_inference_SimpleNet_construct'], graph_id=0)
assert validator.check_node_inputs_has('Send-2', ['pipeline_inference_SimpleNet_construct'], graph_id=0)
assert validator.check_node_inputs_has('Send-0', ['pipeline_inference_SimpleNet_construct'], graph_id=1)
assert validator.check_node_inputs_has('Send-1', ['pipeline_inference_SimpleNet_construct'], graph_id=1)
def test_pipeline_inference_last_stage():
@ -133,8 +132,7 @@ def test_pipeline_inference_last_stage():
x = Tensor(np.ones((batch_size, hidden_size)), mindspore.float32)
phase = compile_infer_net(net, x)
validator = ParallelValidator(net, phase)
assert validator.check_node_inputs_has('Receive-0', ['network.fc0.weight'], graph_id=1)
assert validator.check_node_inputs_has('call @graph_0', ['network.fc1.weight', 'Receive-0', 'Receive-2'],
assert validator.check_node_inputs_has('call @graph_0', ['network.fc1.weight', 'network.fc0.weight', 'Receive-1'],
graph_id=1)
@ -156,7 +154,6 @@ def test_pipeline_inference_result_broadcast():
x = Tensor(np.ones((batch_size, hidden_size)), mindspore.float32)
phase = compile_infer_net(net, x)
validator = ParallelValidator(net, phase)
assert validator.check_node_inputs_has('Receive-0', ['network.fc0.weight'], graph_id=1)
assert validator.check_node_inputs_has('call @graph_0', ['network.fc1.weight', 'Receive-0', 'Receive-2'],
assert validator.check_node_inputs_has('call @graph_0', ['network.fc1.weight', 'network.fc0.weight', 'Receive-1'],
graph_id=1)
assert validator.check_node_inputs_has('AllReduce-0', ['Depend-0'], graph_id=1)

View File

@ -135,7 +135,10 @@ def test_grad_sens_parameter_type():
b_layout = ([64], [-1, -1], [64, 64], 0, True, '')
sens_layout = ([8, 8], [1, -1], [16, 64], 0, True, '')
expect_dict = {'x': x_layout, 'y': y_layout, 'b': b_layout, 'sens': sens_layout}
assert net.parameter_layout_dict == expect_dict
assert net.parameter_layout_dict['x'][0:6] == expect_dict['x']
assert net.parameter_layout_dict['y'][0:6] == expect_dict['y']
assert net.parameter_layout_dict['b'][0:6] == expect_dict['b']
assert net.parameter_layout_dict['sens'][0:6] == expect_dict['sens']
def test_grad_sens_tensor_type():

View File

@ -73,7 +73,7 @@ class ParallelValidator:
if param_name not in self._parameter_layout_dict.keys():
return False
return self._parameter_layout_dict[param_name] == layout
return self._parameter_layout_dict[param_name][0:6] == layout
def check_parameter_shape(self, param_name: str, shape: [tuple, list]) -> bool:
"""Verify parameter shape"""