!65761 pipeline parallel inference
Merge pull request !65761 from chenweifeng/feature-2.3-pipeline-inference
Commit: 0aa21ef058
@@ -19,6 +19,7 @@
#include <memory>
#include <string>
#include <vector>
#include <tuple>
#include <unordered_map>

#include "ir/func_graph.h"

@@ -28,6 +29,7 @@
#include "frontend/parallel/tensor_layout/tensor_layout.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/parameter_manager.h"
#include "frontend/parallel/tensor_layout/shared_parameter.h"

namespace mindspore {
namespace parallel {

@@ -226,6 +228,23 @@ py::dict GetParallelCNodeInfoFromSubGraph(const FuncGraphPtr &sub_graph, const F
  }
  return cnode_info_dict;
}

std::tuple<bool, bool, int64_t, int64_t> GetSharedParameterInfo(const AnfNodePtr &param) {
  MS_EXCEPTION_IF_NULL(param);
  bool is_pipeline_shared = false;
  bool is_send = false;
  int64_t peer_rank = 0;
  int64_t sr_tag = 0;

  auto shared_params = param->user_data<parallel::SharedParameter>();
  if (shared_params) {
    is_pipeline_shared = shared_params->pipeline_shared();
    is_send = shared_params->is_send();
    peer_rank = shared_params->peer_rank();
    sr_tag = shared_params->sr_tag();
  }
  return std::tuple(is_pipeline_shared, is_send, peer_rank, sr_tag);
}
}  // namespace

py::dict GetParameterLayoutFromGraph(const FuncGraphPtr &graph) {

@@ -261,8 +280,10 @@ py::dict GetParameterLayoutFromGraph(const FuncGraphPtr &graph) {
    int64_t field_size = tensor_layout->get_field_size();
    bool uniform_split = tensor_layout->uniform_split();
    const std::string &opt_shard_group = tensor_layout->opt_shard_group();
    py::tuple layout =
      py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group);

    auto [is_pipeline_shared, is_send, peer_rank, sr_tag] = GetSharedParameterInfo(para);
    py::tuple layout = py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split,
                                      opt_shard_group, is_pipeline_shared, is_send, peer_rank, sr_tag);
    for (auto &name : names) {
      dict[py::str(name)] = layout;
    }

@@ -285,8 +306,12 @@ py::dict GetParameterLayoutFromResource(const pipeline::ResourcePtr &resource) {
    int64_t field_size = layout->get_field_size();
    bool uniform_split = layout->get_uniform_split();
    const std::string &opt_shard_group = layout->get_opt_shard_group();
    py::tuple layout_tuple =
      py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group);
    bool is_pipeline_shared = layout->pipeline_shared();
    bool is_send = layout->is_send();
    int64_t peer_rank = layout->peer_rank();
    int64_t sr_tag = layout->sr_tag();
    py::tuple layout_tuple = py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split,
                                            opt_shard_group, is_pipeline_shared, is_send, peer_rank, sr_tag);
    dict[py::str(name)] = layout_tuple;
  }
  return dict;
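With this change each entry in `parameter_layout_dict` carries ten fields instead of six; the last four describe pipeline-shared parameters. A minimal sketch of how the extended tuple can be unpacked on the Python side (field order taken from the `py::make_tuple` call above; `net` is assumed to be an already compiled cell):

    # Sketch only: unpack the 10-field layout entry; variable names are illustrative.
    for name, layout in net.parameter_layout_dict.items():
        (dev_arrangement, tensor_map, slice_shape, field_size,
         uniform_split, opt_shard_group,
         is_pipeline_shared, is_send, peer_rank, sr_tag) = layout
        if is_pipeline_shared:
            role = "send" if is_send else "receive"
            print(f"{name}: shared across stages ({role}), peer rank {peer_rank}, sr_tag {sr_tag}")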
@@ -73,7 +73,7 @@ void SpreadFineGrainedInterleavedIndexForForwardCommNodes(const CNodePtr &cnode,
    if (pre_cnode == nullptr) {
      continue;
    }
    // BFS end search contition
    // BFS end search condition
    if (IsPrimitiveCNode(pre_cnode, prim::kPrimStridedSlice) &&
        GetCNodePrimitive(pre_cnode)->HasAttr(kAttrFineGrainedInterleavedBlockIndex)) {
      pre_cnode->AddAttr("fine_grained_interleaved_border", MakeValue<size_t>(0));

@@ -219,6 +219,10 @@ void LabelFineGrainedInterleavedIndex(const FuncGraphPtr &graph) {
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);

  if (!IsTraining(manager)) {
    return;
  }

  FuncGraphPtr forward_graph = graph;
  FuncGraphPtr backward_graph = graph;
  auto context = MsContext::GetInstance();
@@ -40,6 +40,7 @@
#include "frontend/parallel/graph_util/pipeline_split_utils.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "frontend/parallel/graph_util/graph_splitter.h"
#include "frontend/parallel/tensor_layout/shared_parameter.h"
#include "ir/anf.h"
#include "ir/graph_utils.h"
#include "include/common/utils/comm_manager.h"

@@ -71,6 +72,45 @@ static AbstractBasePtr GetRealAbstract(const AnfNodePtr &node) {
  return node->abstract();
}

void PipelineTransformer::UpdateParameterSharedInfo(const AnfNodePtr &node, const AnfNodePtr &communcate_op,
                                                    bool is_send) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(communcate_op);

  if (!node->isa<Parameter>()) {
    return;
  }
  auto root_param = node;
  if (node->func_graph() != root_) {
    root_param = GetArgumentsByParameter(node);
    MS_EXCEPTION_IF_NULL(root_param);
  }

  // get communication info from cnode.
  auto prim = GetCNodePrimitive(communcate_op);
  MS_EXCEPTION_IF_NULL(prim);

  auto sr_tag_attr = prim->GetAttr(SR_TAG);
  MS_EXCEPTION_IF_NULL(sr_tag_attr);
  auto sr_tag = GetValue<int64_t>(sr_tag_attr);
  auto peer_rank_attr = is_send ? prim->GetAttr(DEST_RANK) : prim->GetAttr(SRC_RANK);
  MS_EXCEPTION_IF_NULL(peer_rank_attr);
  auto peer_rank = GetValue<int64_t>(peer_rank_attr);
  auto group_attr = prim->GetAttr(GROUP);
  MS_EXCEPTION_IF_NULL(group_attr);
  auto group = GetValue<std::string>(group_attr);

  // Use global rank since local group may not exist after loading checkpoint.
  auto rank_list = g_device_manager->FindRankListByHashName(group);
  peer_rank = rank_list.at(peer_rank);

  // update tensor layout.
  auto param = root_param->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(param);
  auto shared_parameters = std::make_shared<SharedParameter>(true, is_send, peer_rank, sr_tag);
  param->set_user_data<SharedParameter>(shared_parameters);
}

TensorInfo PipelineTransformer::GetTensorInfo(const std::pair<OperatorInfoPtr, int> &op_info_pair, bool is_param) {
  if (is_param) {
    auto inputs_tensor_info = op_info_pair.first->inputs_tensor_info();
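UpdateParameterSharedInfo only records (pipeline_shared, is_send, peer_rank, sr_tag) on the root parameter; no communication happens here. A rough Python sketch of how such a record can later be replayed as a point-to-point exchange (it mirrors the `_sync_params` helper added further down in this PR and uses the inner `Send`/`Receive` ops; the class name is illustrative):

    import mindspore as ms
    from mindspore.ops.operations._inner_ops import Send, Receive

    class _ReplaySharedParam(ms.nn.Cell):
        # Hypothetical helper: replays the recorded peer_rank/sr_tag as a Send or Receive.
        def __init__(self, param, is_send, peer_rank, sr_tag):
            super().__init__()
            self.param = param
            self.is_send = is_send
            self.ret = ms.Tensor([0])
            if is_send:
                self.send = Send(sr_tag=sr_tag, dest_rank=peer_rank)
            else:
                self.recv = Receive(sr_tag=sr_tag, src_rank=peer_rank,
                                    shape=param.shape, dtype=param.dtype)

        def construct(self):
            if self.is_send:
                return ms.ops.functional.depend(self.ret, self.send(self.param))
            self.param = self.recv(self.ret)
            return ms.ops.functional.depend(self.ret, self.param)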
@@ -310,7 +350,7 @@ size_t PipelineTransformer::GetBatchAxisForInput(const AnfNodeIndexSet &input_no
      }
    }
  }
  if (batch_axis_count != kSizeOne) {
  if (is_train_ && batch_axis_count != kSizeOne) {
    MS_LOG(EXCEPTION)
      << "For pipeline parallelism, micro_size partitioning of the input along a certain dimension is and "
      << "is only allowed, but it is found that " << batch_axis_count << " to be partitioned.";

@@ -1112,7 +1152,9 @@ AnfNodePtr PipelineTransformer::HandleParameterGraph(const AnfNodePtr &node, con
    }
    (void)parameter_color_map_[root_param].insert(user_stage);
    auto graph = enable_share_cell_ ? shared_cell_ : main_graph_;
    return InsertReceive(graph, argument, use_node, SizeToInt(pos), user_stage, stage, micro, parameter);
    auto recv_node = InsertReceive(graph, argument, use_node, SizeToInt(pos), user_stage, stage, micro, parameter);
    UpdateParameterSharedInfo(root_param, recv_node, false);
    return recv_node;
  }
  // insert send
  if (Reuse(argument, user_stage, ops, DEST_RANK)) {

@@ -1121,6 +1163,7 @@ AnfNodePtr PipelineTransformer::HandleParameterGraph(const AnfNodePtr &node, con
  auto send_out = InsertSend(argument, user_stage, stage_, micro);
  send_out.depend->set_user_data<Type>(DTYPE, send_out.type);
  send_out.depend->set_user_data<ValueList>(SHAPE, send_out.shape);
  UpdateParameterSharedInfo(argument, send_out.depend, true);
  return send_out.depend;
}

@@ -1374,7 +1417,7 @@ std::vector<AnfNodePtr> PipelineTransformer::FetchSend(const AnfNodePtr &node, b
  for (auto &user : shared_cell_users_) {
    auto cuser = user->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cuser);
    auto value = cuser->GetPrimalAttr(MICRO);
    auto value = shared_cell_users_.size() > 1 ? cuser->GetPrimalAttr(MICRO) : MakeValue(int64_t(0));
    MS_EXCEPTION_IF_NULL(value);
    send_input = single_pipeline_end ? user : CreateTupleGetItemNode(main_graph_, user, end_index);
    (void)(depends.emplace_back(GenNewSendFromOld(node, send_input, value)));

@@ -1388,6 +1431,11 @@ void PipelineTransformer::HandleGraphOutputs(const std::vector<AnfNodePtr> &node
  SeparateParamBorder(nodes, true, &pipeline_params, &pipeline_ends);
  std::vector<AnfNodePtr> sends;
  SetNodeAbstract(pipeline_ends);

  // Create root graph output before modify subgraph(shared cell).
  // This process order is crucial when the output of subgraph is directly used as root graph.
  auto zero_outputs = GetZeroOutputs(main_graph_);

  size_t ends_size = pipeline_ends.size();
  bool single_pipeline_end = ends_size == 1;
  if (single_pipeline_end) {

@@ -1408,7 +1456,9 @@ void PipelineTransformer::HandleGraphOutputs(const std::vector<AnfNodePtr> &node
  }
  for (auto &node : pipeline_params) {
    auto params = FetchSend(node, true, false, 0);
    (void)std::copy(params.begin(), params.end(), std::back_inserter(sends));
    if (is_train_) {
      (void)std::copy(params.begin(), params.end(), std::back_inserter(sends));
    }
  }
  for (size_t i = 0; i < ends_size; i++) {
    auto node = pipeline_ends[i];

@@ -1416,7 +1466,6 @@ void PipelineTransformer::HandleGraphOutputs(const std::vector<AnfNodePtr> &node
    (void)std::copy(ends.begin(), ends.end(), std::back_inserter(sends));
  }
  auto make_tuple = CreateMakeTupleNode(main_graph_, sends);
  auto zero_outputs = GetZeroOutputs(main_graph_);
  std::vector<AnfNodePtr> out = {NewValueNode(prim::kPrimDepend), zero_outputs, make_tuple};
  auto out_node = main_graph_->NewCNode(out);
  out_node->set_abstract(zero_outputs->abstract());

@@ -1486,7 +1535,11 @@ std::vector<AnfNodePtr> PipelineTransformer::FetchRecv(const AnfNodePtr &node, b
      recv_input = user->input(input_pos);
      recv = GenNewRecvFromOld(node, recv_input, value);
      for (auto &share_user : shared_cell_users_) {
        manager_->SetEdge(share_user, input_pos, recv);
        if (is_train_) {
          manager_->SetEdge(share_user, input_pos, recv);
        } else {
          manager_->SetEdge(share_user, input_pos, recv_input);
        }
      }
      node->set_user_data<bool>(ORIGIN_INPUT_IS_PARAM, std::make_shared<bool>(true));
    } else {

@@ -1501,7 +1554,7 @@ std::vector<AnfNodePtr> PipelineTransformer::FetchRecv(const AnfNodePtr &node, b
    for (auto &user : shared_cell_users_) {
      auto cuser = user->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cuser);
      auto value = cuser->GetPrimalAttr(MICRO);
      auto value = shared_cell_users_.size() > 1 ? cuser->GetPrimalAttr(MICRO) : MakeValue(int64_t(0));
      MS_EXCEPTION_IF_NULL(value);
      if (enable_share_cell_ || !is_train_) {
        auto recv_tensor = TensorConstructUtils::CreateZerosTensor(kFloat16, {1});

@@ -1549,6 +1602,7 @@ void PipelineTransformer::ResetSharedCellParamAndArgu(
  MS_LOG(DEBUG) << "The shared cell origin params size is " << params.size() << ", new params size is "
                << new_params.size();
  manager_->SetParameters(shared_cell_, new_params);
  shared_cell_->set_fv_param_count(new_params.size());
  // set call inputs
  size_t user_index = 0;
  for (auto &user : shared_cell_users_) {

@@ -118,6 +118,7 @@ class PipelineTransformer {
  bool GetStageByArgument(const CNodePtr &node, size_t index, const std::vector<AnfNodePtr> &parameters,
                          const NodeUsersMap &node_users_map, std::set<int64_t> *const parameter_stage);
  size_t GetBatchAxisForInput(const AnfNodeIndexSet &input_node_users) const;
  void UpdateParameterSharedInfo(const AnfNodePtr &node, const AnfNodePtr &communcate_op, bool is_send);
  FuncGraphManagerPtr manager_;
  int64_t stage_;
  FuncGraphPtr root_;
@@ -3025,6 +3025,50 @@ static void MoveMicroMirrorOutCallFunc(const FuncGraphPtr &root) {
  }
}

static void BroadcastMultiOutputs(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager, const Group &group) {
  auto output = root->get_return()->input(1)->cast<CNodePtr>();
  auto output_abstract = output->abstract();
  MS_EXCEPTION_IF_NULL(output_abstract);
  auto abstract_tuple = output_abstract->cast<abstract::AbstractTuplePtr>();
  MS_EXCEPTION_IF_NULL(abstract_tuple);
  auto abstract_list = abstract_tuple->elements();

  AnfNodePtrList make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)};
  for (size_t i = 0; i < abstract_list.size(); i++) {
    auto abstract = abstract_list[i];
    MS_EXCEPTION_IF_NULL(abstract);

    // TupleGetItem
    auto idx = NewValueNode(SizeToLong(i));
    CNodePtr tuple_getitem = root->NewCNode({NewValueNode(prim::kPrimTupleGetItem), output, idx});
    MS_EXCEPTION_IF_NULL(tuple_getitem);
    tuple_getitem->set_abstract(abstract);

    // Depend: prevent disorder and CSE
    if (i > 0) {
      tuple_getitem = root->NewCNode({NewValueNode(prim::kPrimDepend), tuple_getitem, make_tuple_input[i]});
      MS_EXCEPTION_IF_NULL(tuple_getitem);
      tuple_getitem->set_abstract(abstract);
    }

    // Allreduce
    CNodePtr allreduce = root->NewCNode({NewValueNode(prim::kPrimAllReduce), tuple_getitem});
    MS_EXCEPTION_IF_NULL(allreduce);
    allreduce->set_abstract(abstract);
    common::AnfAlgo::SetNodeAttr(OP, MakeValue(REDUCE_OP_SUM), allreduce);
    common::AnfAlgo::SetNodeAttr(GROUP, MakeValue(group.name()), allreduce);
    // Disable GE allreduce fusion.
    common::AnfAlgo::SetNodeAttr(FUSION, MakeValue(static_cast<int64_t>(0)), allreduce);

    make_tuple_input.push_back(allreduce);
  }

  CNodePtr make_tuple_node = root->NewCNode(make_tuple_input);
  MS_EXCEPTION_IF_NULL(make_tuple_node);
  make_tuple_node->set_abstract(abstract_tuple);
  (void)manager->Replace(output, make_tuple_node);
}

static void BroadcastLastResult(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
  auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
  auto pipeline_result_broadcast = parallel::ParallelContext::GetInstance()->pipeline_result_broadcast();

@@ -3032,12 +3076,6 @@ static void BroadcastLastResult(const FuncGraphPtr &root, const FuncGraphManager
    return;
  }

  auto return_node = root->get_return();
  const auto &abstract = return_node->abstract();
  if (abstract->isa<abstract::AbstractTuple>()) {
    return;
  }

  std::vector<int64_t> rank_list = g_device_manager->GetDeviceListBetweenStage();
  Group group;
  if (g_device_manager->CreateGroup(rank_list, &group) != SUCCESS) {

@@ -3045,6 +3083,12 @@ static void BroadcastLastResult(const FuncGraphPtr &root, const FuncGraphManager
                 << rank_list;
  }

  auto return_node = root->get_return();
  const auto &abstract = return_node->abstract();
  if (abstract->isa<abstract::AbstractTuple>()) {
    return BroadcastMultiOutputs(root, manager, group);
  }

  InsertAllReduceToNodeInput(return_node, group.name(), PARALLEL_RESULT_BROADCAST);
  return_node->input(1)->set_abstract(abstract);
}
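BroadcastMultiOutputs extends result broadcasting to tuple outputs: each element is pulled out with TupleGetItem, chained with Depend to keep ordering, and AllReduced over the inter-stage group, so stages other than the last one also return valid values. On the user side this is driven purely by the context flag, as exercised by the tests added later in this PR; a minimal configuration sketch (assuming `Net` is a pipelined inference cell and `x` an input tensor defined elsewhere):

    from mindspore import context

    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", full_batch=True,
                                      pipeline_stages=4, pipeline_result_broadcast=True)
    net = Net()            # pipelined inference cell (assumed)
    net.set_train(False)
    out = net(x)           # every stage now observes the last stage's (possibly tuple) output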
@@ -0,0 +1,54 @@
/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_SHARED_PARAMETER_TENSOR_LAYOUT_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_SHARED_PARAMETER_TENSOR_LAYOUT_H_

#include <string>

namespace mindspore {
namespace parallel {
class SharedParameter {
 public:
  SharedParameter(bool pipeline_shared, bool is_send, int64_t peer_rank, int64_t sr_tag)
      : pipeline_shared_(pipeline_shared), is_send_(is_send), peer_rank_(peer_rank), sr_tag_(sr_tag) {}
  ~SharedParameter() = default;

  void set_pipeline_shared(bool pipeline_shared) { pipeline_shared_ = pipeline_shared; }
  bool pipeline_shared() const { return pipeline_shared_; }

  void set_is_send(bool is_send) { is_send_ = is_send; }
  bool is_send() const { return is_send_; }

  void set_peer_rank(int64_t peer_rank) { peer_rank_ = peer_rank; }
  int64_t peer_rank() const { return peer_rank_; }

  void set_sr_tag(int64_t sr_tag) { sr_tag_ = sr_tag; }
  int64_t sr_tag() const { return sr_tag_; }

  // Key for user data.
  constexpr static char key[] = "SharedParameter";

 private:
  bool pipeline_shared_ = false;
  bool is_send_ = false;
  int64_t peer_rank_{0};
  int64_t sr_tag_{0};
};
}  // namespace parallel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_SHARED_PARAMETER_TENSOR_LAYOUT_H_
@@ -28,6 +28,7 @@
#include "utils/system/sha256.h"
#include "include/common/utils/utils.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/tensor_layout/shared_parameter.h"
#include "mindspore/core/utils/file_utils.h"

#if defined(__linux__) && defined(WITH_BACKEND)

@@ -76,6 +77,13 @@ void BuildLayout(const FuncGraphPtr &func_graph, mind_ir::ModelProto *model) {
      layoutProto->set_field_size(field_size);
      layoutProto->set_uniform_split(uniform_split);
      layoutProto->set_opt_shard_group(opt_shard_group);
      auto shared_param = para->user_data<parallel::SharedParameter>();
      if (shared_param) {
        layoutProto->set_pipeline_shared(shared_param->pipeline_shared());
        layoutProto->set_is_send(shared_param->is_send());
        layoutProto->set_peer_rank(shared_param->peer_rank());
        layoutProto->set_sr_tag(shared_param->sr_tag());
      }
    }
  }
}
@@ -2299,6 +2299,18 @@ const LayoutMap MSANFModelParser::ParseLayout(const mind_ir::ModelProto &model_p
    cur_layout->set_uniform_split(uniform_spilt);
    cur_layout->set_opt_shard_group(opt_shard_group);

    // Check optional field for backward compatibility.
    if (layout_proto.has_pipeline_shared()) {
      bool pipeline_shared = layout_proto.pipeline_shared();
      bool is_send = layout_proto.is_send();
      int64_t peer_rank = layout_proto.peer_rank();
      int64_t sr_tag = layout_proto.sr_tag();

      cur_layout->set_pipeline_shared(pipeline_shared);
      cur_layout->set_is_send(is_send);
      cur_layout->set_peer_rank(peer_rank);
      cur_layout->set_sr_tag(sr_tag);
    }
    ret[name] = cur_layout;
  }
  return ret;
@@ -42,6 +42,14 @@ class Layout {
  void set_uniform_split(bool uniform_split) { uniform_split_ = uniform_split; }
  const std::string &get_opt_shard_group() const { return opt_shard_group_; }
  void set_opt_shard_group(const std::string &opt_shard_group) { opt_shard_group_ = opt_shard_group; }
  void set_pipeline_shared(bool pipeline_shared) { pipeline_shared_ = pipeline_shared; }
  bool pipeline_shared() const { return pipeline_shared_; }
  void set_is_send(bool is_send) { is_send_ = is_send; }
  bool is_send() const { return is_send_; }
  void set_peer_rank(int64_t peer_rank) { peer_rank_ = peer_rank; }
  int64_t peer_rank() const { return peer_rank_; }
  void set_sr_tag(int64_t sr_tag) { sr_tag_ = sr_tag; }
  int64_t sr_tag() const { return sr_tag_; }

 private:
  std::vector<int64_t> device_arrangement_{};

@@ -50,6 +58,11 @@ class Layout {
  int64_t field_size_ = 0;
  bool uniform_split_ = false;
  std::string opt_shard_group_ = "";
  // pipeline stage shared param info
  bool pipeline_shared_ = false;
  bool is_send_ = false;
  int64_t peer_rank_{0};
  int64_t sr_tag_{0};
};
using LayoutPtr = std::shared_ptr<Layout>;
using LayoutMap = std::map<string, LayoutPtr>;
@@ -233,6 +233,10 @@ message LayoutProto {
  optional int64 field_size = 5;
  optional bool uniform_split = 6;
  optional string opt_shard_group = 7;
  optional bool pipeline_shared = 8;
  optional bool is_send = 9;
  optional int64 peer_rank = 10;
  optional int64 sr_tag = 11;
}

message PrimitiveProto {
@@ -18,8 +18,9 @@ from __future__ import absolute_import
from mindspore.parallel.algo_parameter_config import get_algo_parameters, reset_algo_parameters, \
    set_algo_parameters
from mindspore.parallel.checkpoint_transform import rank_list_for_transform, transform_checkpoint_by_rank, \
    transform_checkpoints, merge_pipeline_strategys
    transform_checkpoints, merge_pipeline_strategys, sync_pipeline_shared_parameters
from mindspore.parallel.shard import shard

__all__ = ["set_algo_parameters", "reset_algo_parameters", "get_algo_parameters", "rank_list_for_transform",
           "transform_checkpoint_by_rank", "transform_checkpoints", "merge_pipeline_strategys", "shard"]
           "transform_checkpoint_by_rank", "transform_checkpoints", "merge_pipeline_strategys", "shard",
           "sync_pipeline_shared_parameters"]
@@ -22,6 +22,7 @@ from collections import defaultdict
import numpy as np
import mindspore as ms
from mindspore.common import dtype as mstype
from mindspore.parallel._utils import _is_in_auto_parallel_mode
from mindspore.parallel._parallel_serialization import _rank_list_for_transform_parallel_checkpoint, \
    _transform_parallel_checkpoint, _get_device_num_from_strategy, _make_dir, \
    _extract_layout_map, _extract_src_dst_layout_map, _parameter_not_in_local_stage, _extract_pipeline_stage_num, \

@@ -29,7 +30,7 @@ from mindspore.parallel._parallel_serialization import _rank_list_for_transform_


__all__ = ["merge_pipeline_strategys", "rank_list_for_transform", "transform_checkpoint_by_rank",
           "transform_checkpoints"]
           "transform_checkpoints", "sync_pipeline_shared_parameters"]


def merge_pipeline_strategys(src_strategy_dirs, dst_strategy_file):

@@ -336,3 +337,121 @@ def transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix,
            ms.save_checkpoint(transform_param_list, save_checkpoint_file_name)
        del param_total_dict_copy
    del param_total_dict


def _sync_params(name, param, layout):
    """synchronize single parameter"""
    if len(layout) < 10:
        ms.log.warning("The layout dict does not contain the pipeline_shared_param info %s", name)
        return

    pipeline_shared = layout[6]
    if not pipeline_shared:
        return

    is_send = layout[7]
    peer_rank = layout[8]
    sr_tag = layout[9]

    class SharedParameterSyncCell(ms.nn.Cell):
        """synchronize cell"""
        def __init__(self, param, is_send, peer_rank, sr_tag):
            super().__init__()
            self.param = param
            self.is_send = is_send
            self.ret = ms.Tensor([0])

            from mindspore.ops.operations._inner_ops import Send, Receive
            if self.is_send:
                self.send = Send(sr_tag=sr_tag, dest_rank=peer_rank)
            else:
                self.receive = Receive(sr_tag=sr_tag, src_rank=peer_rank, shape=param.shape, dtype=param.dtype)

        def construct(self):
            if self.is_send:
                out = self.send(self.param)
                return ms.ops.functional.depend(self.ret, out)

            self.param = self.receive(self.ret)
            return ms.ops.functional.depend(self.ret, self.param)

    sync_net = SharedParameterSyncCell(param, is_send, peer_rank, sr_tag)
    sync_net()


def sync_pipeline_shared_parameters(net):
    """Synchronize pipeline parallel stage shared parameters.

    Parameters may be shared between different stages in pipeline parallel inference. For example, `embedding table`
    is shared by the `WordEmbedding` layer and the `LMHead` layer, which are usually split into different stages. It
    is necessary to perform synchronization after `embedding table` changes.

    Note:
        The network should be compiled before synchronizing the pipeline parallel stage shared parameters.

    Args:
        net (nn.Cell): the inference network.

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> from mindspore import nn, ops, Parameter, Tensor, lazy_inline
        >>> class VocabEmbedding(nn.Cell):
        ...     def __init__(self, vocab_size, embedding_size):
        ...         super().__init__()
        ...         self.embedding_table = Parameter(Tensor(np.ones([vocab_size, embedding_size])), name='embedding')
        ...         self.gather = ops.Gather()
        ...
        ...     def construct(self, x):
        ...         output = self.gather(self.embedding_table, x, 0)
        ...         return output, self.embedding_table.value()
        ...
        >>> class LMHead(nn.Cell):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.matmul = ops.MatMul(transpose_b=True)
        ...
        ...     def construct(self, state, embed):
        ...         state = state.reshape(-1, state.shape[-1])
        ...         return self.matmul(state, embed)
        ...
        >>> class Network(nn.Cell):
        ...     @lazy_inline
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.word_embedding = VocabEmbedding(vocab_size=4, embedding_size=4)
        ...         self.head = LMHead()
        ...
        ...     def construct(self, x):
        ...         x, embed = self.word_embedding(x)
        ...         x = self.head(x, embed)
        ...         return x
        >>>
        >>> net = Network()
        >>> net.word_embedding.pipeline_stage = 0
        >>> net.head.pipeline_stage = 1
        >>> x = Tensor(np.ones((8, 4)), ms.int32)
        >>> net.compile(x)
        >>> ms.parallel.sync_pipeline_shared_parameters(net)
        >>> print(net.word_embedding.embedding_table.asnumpy())
        [[1. 1. 1. 1.]
         [1. 1. 1. 1.]
         [1. 1. 1. 1.]
         [1. 1. 1. 1.]]
    """
    layout_dict = net.parameter_layout_dict
    if _is_in_auto_parallel_mode() and not layout_dict:
        from mindspore.common.api import _get_parameter_layout
        layout_dict = _get_parameter_layout()

    # switch to standalone mode
    parallel_mode = ms.context.get_auto_parallel_context("parallel_mode")
    full_batch = ms.context.get_auto_parallel_context("full_batch")
    ms.context.set_auto_parallel_context(parallel_mode="stand_alone", full_batch=False)

    # synchronize shared parameter
    for name, param in net.parameters_and_names():
        if name in layout_dict:
            _sync_params(name, param, layout_dict[name])

    # restore parallel context
    ms.context.set_auto_parallel_context(parallel_mode=parallel_mode, full_batch=full_batch)
@@ -63,6 +63,7 @@ from mindspore.parallel._parallel_serialization import _convert_to_list, _conver
    _restore_group_info_list
from mindspore.parallel._ps_context import _set_checkpoint_load_status, _store_warm_up_ptr_by_tensor, \
    _store_warm_up_ptr_by_tensor_list, _cache_enable
from mindspore.parallel.checkpoint_transform import sync_pipeline_shared_parameters
from mindspore.train._utils import read_proto
from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir, \
    split_mindir, split_dynamic_mindir

@@ -559,6 +560,7 @@ def _convert_cell_param_and_names_to_dict(save_obj, choice_func):

def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func):
    """Convert nn.Cell to param_list."""
    sync_pipeline_shared_parameters(save_obj)
    param_list = []
    parameter_layout_dict = save_obj.parameter_layout_dict
    if _is_in_auto_parallel_mode() and not parameter_layout_dict:
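Since `_convert_cell_to_param_list` now calls `sync_pipeline_shared_parameters` before collecting parameters, saving a checkpoint from a pipelined inference cell synchronizes the stage-shared weights first. A minimal sketch of the resulting flow (assumes `net` is a compiled pipeline-parallel inference cell and `x` a sample input):

    import mindspore as ms

    net.set_train(False)
    net.compile(x)                                   # builds parameter_layout_dict
    ms.save_checkpoint(net, "pipeline_infer.ckpt")   # shared parameters are synced before saving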
@@ -0,0 +1,263 @@
import numpy as np
import mindspore
import mindspore.communication.management as D
from mindspore import lazy_inline, context, nn, Tensor, Parameter
from mindspore.ops import operations as P
from mindspore.parallel.checkpoint_transform import sync_pipeline_shared_parameters

context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
pipeline_stages = 4


class FC(nn.Cell):
    def __init__(self, shape):
        super().__init__()
        self.w = Parameter(Tensor(np.ones(shape), mindspore.float32), name="weight")
        self.matmul = P.MatMul()

    def construct(self, x):
        return self.matmul(x, self.w)


class WordEmbedding(nn.Cell):
    def __init__(self, shape):
        super().__init__()
        self.w = Parameter(Tensor(np.ones(shape), mindspore.float32), name="weight")
        self.matmul = P.MatMul()

    def construct(self, x):
        return self.matmul(x, self.w), self.w


class LMHead(nn.Cell):
    def __init__(self, shape):
        super().__init__()
        self.w = Parameter(Tensor(np.ones(shape), mindspore.float32), name="weight")
        self.matmul1 = P.MatMul()
        self.matmul2 = P.MatMul()

    def construct(self, x, w):
        x = self.matmul1(x, self.w)
        return self.matmul2(x, w), x + x


class Net(nn.Cell):
    @lazy_inline
    def __init__(self):
        super().__init__()
        shape = (8, 8)
        self.word_embedding = WordEmbedding(shape)
        self.decoder1 = FC(shape)
        self.decoder2 = FC(shape)
        self.lm_head = LMHead(shape)

        self.word_embedding.matmul.shard(((1, 1), (1, 1)))
        self.decoder1.matmul.shard(((1, 1), (1, 1)))
        self.decoder2.matmul.shard(((1, 1), (1, 1)))
        self.lm_head.matmul1.shard(((1, 1), (1, 1)))
        self.lm_head.matmul2.shard(((1, 1), (1, 1)))

        self.word_embedding.pipeline_stage = 0
        self.decoder1.pipeline_stage = 1
        self.decoder2.pipeline_stage = 2
        self.lm_head.pipeline_stage = 3

    def construct(self, x):
        x, w = self.word_embedding(x)
        x = self.decoder1(x)
        x = self.decoder2(x)
        x, y = self.lm_head(x, w)
        return x, y


class PipelineCellInference(nn.Cell):
    def __init__(self, network, micro_batch_num):
        super().__init__()
        self.network = network
        self.micro_batch_num = micro_batch_num
        self.concat = P.Concat()

    def construct(self, x):
        ret_x = ()
        ret_y = ()
        for i in range(self.micro_batch_num):
            micro_batch_size = x.shape[0] // self.micro_batch_num
            start = micro_batch_size * i
            end = micro_batch_size * (i + 1)

            micro_input = x[start:end]
            ret1, ret2 = self.network(micro_input)
            ret_x = ret_x + (ret1,)
            ret_y = ret_y + (ret2,)

        ret_x = self.concat(ret_x)
        ret_y = self.concat(ret_y)
        return ret_x, ret_y


def get_stage_id():
    rank_size = D.get_group_size()
    rank_id = D.get_rank()
    rank_per_stage = rank_size // pipeline_stages
    return rank_id // rank_per_stage


def test_pipeline_inference_basic():
    """
    Feature: Pipeline parallel inference
    Description: Micro batch split
    Expectation: success
    """
    D.init()
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", full_batch=True,
                                      pipeline_stages=pipeline_stages)
    net = PipelineCellInference(Net(), micro_batch_num=2)
    net.set_train(False)

    shape = (8, 8)
    x = Tensor(np.ones(shape), mindspore.float32)
    ret = net(x)

    expect = [[np.zeros(shape, np.float32), np.zeros(shape, np.float32)],
              [np.ones(shape, np.float32) * pow(8, 5), np.ones(shape, np.float32) * pow(8, 4) * 2]]
    is_last_stage = get_stage_id() == pipeline_stages - 1
    assert np.allclose(ret[0].asnumpy(), expect[is_last_stage][0])
    assert np.allclose(ret[1].asnumpy(), expect[is_last_stage][1])


def test_pipeline_inference_broadcast():
    """
    Feature: Pipeline parallel inference
    Description: Broadcast last stage result, multi-output.
    Expectation: success
    """
    D.init()
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", full_batch=True,
                                      pipeline_stages=pipeline_stages, pipeline_result_broadcast=True)
    net = PipelineCellInference(Net(), micro_batch_num=4)
    net.set_train(False)

    shape = (8, 8)
    x = Tensor(np.ones(shape), mindspore.float32)
    ret = net(x)
    print(ret)

    expect = [np.ones(shape, np.float32) * pow(8, 5), np.ones(shape, np.float32) * pow(8, 4) * 2]
    assert np.allclose(ret[0].asnumpy(), expect[0])
    assert np.allclose(ret[1].asnumpy(), expect[1])


def test_pipeline_inference_single_micro_batch():
    """
    Feature: Pipeline parallel inference
    Description: Broadcast last stage result, without micro batch split
    Expectation: success
    """
    D.init()
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", full_batch=True,
                                      pipeline_stages=pipeline_stages, pipeline_result_broadcast=True)
    net = PipelineCellInference(Net(), micro_batch_num=1)
    net.set_train(False)

    shape = (8, 8)
    x = Tensor(np.ones(shape), mindspore.float32)
    ret = net(x)

    print(ret)
    expect = [np.ones(shape, np.float32) * pow(8, 5), np.ones(shape, np.float32) * pow(8, 4) * 2]
    assert np.allclose(ret[0].asnumpy(), expect[0])
    assert np.allclose(ret[1].asnumpy(), expect[1])


def test_pipeline_inference_without_wrapper():
    """
    Feature: Pipeline parallel inference
    Description: Broadcast last stage result, without wrapper
    Expectation: success
    """
    D.init()
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", full_batch=True,
                                      pipeline_stages=pipeline_stages, pipeline_result_broadcast=True)
    net = Net()
    net.set_train(False)

    shape = (8, 8)
    x = Tensor(np.ones(shape), mindspore.float32)
    ret = net(x)
    print(ret)
    expect = [np.ones(shape, np.float32) * pow(8, 5), np.ones(shape, np.float32) * pow(8, 4) * 2]
    assert np.allclose(ret[0].asnumpy(), expect[0])
    assert np.allclose(ret[1].asnumpy(), expect[1])


class PipelineCellInferenceSingleOutput(nn.Cell):
    def __init__(self, network, micro_batch_num):
        super().__init__()
        self.network = network
        self.micro_batch_num = micro_batch_num
        self.concat = P.Concat()

    def construct(self, x):
        ret = ()
        for i in range(self.micro_batch_num):
            micro_batch_size = x.shape[0] // self.micro_batch_num
            start = micro_batch_size * i
            end = micro_batch_size * (i + 1)

            micro_input = x[start:end]
            ret1, _ = self.network(micro_input)
            ret = ret + (ret1,)

        ret = self.concat(ret)
        return ret


def test_pipeline_inference_single_output():
    """
    Feature: Pipeline parallel inference
    Description: Micro batch split
    Expectation: success
    """
    D.init()
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", full_batch=True,
                                      pipeline_stages=pipeline_stages, pipeline_result_broadcast=True)
    net = PipelineCellInferenceSingleOutput(Net(), micro_batch_num=2)
    net.set_train(False)

    shape = (8, 8)
    x = Tensor(np.ones(shape), mindspore.float32)
    ret = net(x)

    print(ret)
    expect = [np.ones(shape, np.float32) * pow(8, 5), np.ones(shape, np.float32) * pow(8, 5)]
    is_last_stage = get_stage_id() == pipeline_stages - 1
    assert np.allclose(ret.asnumpy(), expect[is_last_stage])


def test_pipeline_inference_shared_params():
    """
    Feature: Pipeline parallel inference
    Description: Shared parameters synchronize
    Expectation: success
    """
    D.init()
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", full_batch=True,
                                      pipeline_stages=pipeline_stages, pipeline_result_broadcast=True)
    net = Net()
    net.set_train(False)

    if get_stage_id() == pipeline_stages - 1:
        shape, dtype = net.word_embedding.w.shape, net.word_embedding.w.dtype
        net.word_embedding.w.set_data(Tensor(np.zeros(shape), dtype))

    shape = (8, 8)
    x = Tensor(np.ones(shape), mindspore.float32)
    # compile and synchronize
    net.compile(x)
    sync_pipeline_shared_parameters(net)

    ret = net(x)
    print(ret)
    expect = [np.ones(shape, np.float32) * pow(8, 5), np.ones(shape, np.float32) * pow(8, 4) * 2]
    assert np.allclose(ret[0].asnumpy(), expect[0])
    assert np.allclose(ret[1].asnumpy(), expect[1])
@@ -0,0 +1,32 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import pytest


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single
def test_msrun_pipeline_parallel_inference():
    '''
    Feature: Pipeline inference optimizer parallel.
    Description: Test pipeline inference.
    Expectation: Run success.
    '''
    ret = os.system("msrun --worker_num=4 --local_worker_num=4 --master_addr=127.0.0.1 --master_port=10969 "
                    "--join=True --log_dir=./pipeline_inference_logs pytest -s -v "
                    "pipeline_inference.py")
    assert ret == 0
@@ -57,7 +57,8 @@ def test_get_parameter_layout():
    weight_layout = ([2, 4], [0, -1], [16, 32], 0, True, '')  # device_arrangement = [2, 4], tensor_map = [0, -1]
    expect_dict = {'x': x_layout, 'w1': weight_layout}
    # to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut
    assert net.parameter_layout_dict == expect_dict
    assert net.parameter_layout_dict["x"][0:6] == expect_dict["x"]
    assert net.parameter_layout_dict["w1"][0:6] == expect_dict["w1"]


if __name__ == '__main__':
@@ -111,9 +111,8 @@ def test_pipeline_inference_first_stage():
    x = Tensor(np.ones((batch_size, hidden_size)), mindspore.float32)
    phase = compile_infer_net(net, x)
    validator = ParallelValidator(net, phase)
    assert validator.check_node_inputs_has('Send-0', ['network.fc0.weight'], graph_id=0)
    assert validator.check_node_inputs_has('Send-1', ['pipeline_inference_SimpleNet_construct'], graph_id=0)
    assert validator.check_node_inputs_has('Send-2', ['pipeline_inference_SimpleNet_construct'], graph_id=0)
    assert validator.check_node_inputs_has('Send-0', ['pipeline_inference_SimpleNet_construct'], graph_id=1)
    assert validator.check_node_inputs_has('Send-1', ['pipeline_inference_SimpleNet_construct'], graph_id=1)


def test_pipeline_inference_last_stage():

@@ -133,8 +132,7 @@ def test_pipeline_inference_last_stage():
    x = Tensor(np.ones((batch_size, hidden_size)), mindspore.float32)
    phase = compile_infer_net(net, x)
    validator = ParallelValidator(net, phase)
    assert validator.check_node_inputs_has('Receive-0', ['network.fc0.weight'], graph_id=1)
    assert validator.check_node_inputs_has('call @graph_0', ['network.fc1.weight', 'Receive-0', 'Receive-2'],
    assert validator.check_node_inputs_has('call @graph_0', ['network.fc1.weight', 'network.fc0.weight', 'Receive-1'],
                                           graph_id=1)

@@ -156,7 +154,6 @@ def test_pipeline_inference_result_broadcast():
    x = Tensor(np.ones((batch_size, hidden_size)), mindspore.float32)
    phase = compile_infer_net(net, x)
    validator = ParallelValidator(net, phase)
    assert validator.check_node_inputs_has('Receive-0', ['network.fc0.weight'], graph_id=1)
    assert validator.check_node_inputs_has('call @graph_0', ['network.fc1.weight', 'Receive-0', 'Receive-2'],
    assert validator.check_node_inputs_has('call @graph_0', ['network.fc1.weight', 'network.fc0.weight', 'Receive-1'],
                                           graph_id=1)
    assert validator.check_node_inputs_has('AllReduce-0', ['Depend-0'], graph_id=1)
@@ -135,7 +135,10 @@ def test_grad_sens_parameter_type():
    b_layout = ([64], [-1, -1], [64, 64], 0, True, '')
    sens_layout = ([8, 8], [1, -1], [16, 64], 0, True, '')
    expect_dict = {'x': x_layout, 'y': y_layout, 'b': b_layout, 'sens': sens_layout}
    assert net.parameter_layout_dict == expect_dict
    assert net.parameter_layout_dict['x'][0:6] == expect_dict['x']
    assert net.parameter_layout_dict['y'][0:6] == expect_dict['y']
    assert net.parameter_layout_dict['b'][0:6] == expect_dict['b']
    assert net.parameter_layout_dict['sens'][0:6] == expect_dict['sens']


def test_grad_sens_tensor_type():
@@ -73,7 +73,7 @@ class ParallelValidator:

        if param_name not in self._parameter_layout_dict.keys():
            return False
        return self._parameter_layout_dict[param_name] == layout
        return self._parameter_layout_dict[param_name][0:6] == layout

    def check_parameter_shape(self, param_name: str, shape: [tuple, list]) -> bool:
        """Verify parameter shape"""