fix_pipeline_split_param_shared_bug

This commit is contained in:
lichenever 2020-12-03 15:07:57 +08:00
parent 78e131cf15
commit 818e920f02
4 changed files with 117 additions and 19 deletions

View File

@ -28,6 +28,8 @@
#include "frontend/parallel/context.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/node_check.h"
#include "ir/anf.h"
#include "base/core_ops.h"
#include "utils/comm_manager.h"
#include "utils/ms_context.h"
@ -136,6 +138,11 @@ OperatorInfoPtr PipelineTransformer::CreateOpInfo(const CNodePtr &cnode) {
std::pair<OperatorInfoPtr, TensorInfoPtr> PipelineTransformer::GetOpInfo(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
// handle send/recv a parameter
if (node->isa<Parameter>()) {
MS_LOG(INFO) << "parameter: " << node->ToString() << " need to be send/recv.";
return std::make_pair(nullptr, nullptr);
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
OperatorInfoPtr op_info = nullptr;
@ -170,6 +177,23 @@ std::pair<OperatorInfoPtr, TensorInfoPtr> PipelineTransformer::GetOpInfo(const A
return std::make_pair(op_info, std::make_shared<TensorInfo>(tensor_info));
}
// Derives the tensor layout info for a Parameter that must be sent/received across
// pipeline stages. A parameter has no operator of its own, so the layout is taken
// from the first pipeline-care consumer of the parameter; the OperatorInfoPtr half
// of the pair is always nullptr for parameters. Returns {nullptr, nullptr} when no
// pipeline-care user exists.
std::pair<OperatorInfoPtr, TensorInfoPtr> PipelineTransformer::GetParameterPair(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  for (const auto &user : manager_->node_users()[node]) {
    auto consumer = user.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(consumer);
    if (!IsPipelineCareNode(consumer)) {
      continue;
    }
    auto consumer_info = CreateOpInfo(consumer);
    MS_EXCEPTION_IF_NULL(consumer_info);
    // user.second is the 1-based input position on the consumer (input 0 is the
    // primitive), hence the -1 to index inputs_tensor_info().
    auto input_slot = IntToSize(user.second) - 1;
    auto tensor_info = consumer_info->inputs_tensor_info()[input_slot];
    return std::make_pair(nullptr, std::make_shared<TensorInfo>(tensor_info));
  }
  return std::make_pair(nullptr, nullptr);
}
void PipelineTransformer::DoBroadCast(const FuncGraphPtr &func) {
auto need_coloring = true;
while (need_coloring) {
@ -240,6 +264,7 @@ void PipelineTransformer::HandleSharedParameter() {
auto depend_op = CreatOpInstance(depend_attrs, DEPEND, "");
std::vector<AnfNodePtr> depend_input = {NewValueNode(depend_op), parameter, make_tuple};
auto depend = graph->NewCNode(depend_input);
depend->set_abstract(parameter->abstract());
manager_->SetEdge(node, user.second, depend);
break;
} else {
@ -301,7 +326,12 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
auto send_op = CreatOpInstance(attrs, SEND, "send");
auto send_node = NewValueNode(send_op);
auto prim = GetValueNode<PrimitivePtr>(send_node);
auto op_info_pair = GetOpInfo(parameter);
std::pair<OperatorInfoPtr, TensorInfoPtr> op_info_pair;
if (parameter->isa<Parameter>()) {
op_info_pair = GetParameterPair(parameter);
} else {
op_info_pair = GetOpInfo(parameter);
}
auto tensor_info = op_info_pair.second;
MS_EXCEPTION_IF_NULL(tensor_info);
auto slice_shape = tensor_info->slice_shape();
@ -314,6 +344,8 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
auto depend_op = CreatOpInstance(depend_attrs, DEPEND, "depend");
std::vector<AnfNodePtr> depend_input = {NewValueNode(depend_op), parameter, send};
auto depend = graph->NewCNode(depend_input);
auto abstract = parameter->abstract();
depend->set_abstract(abstract);
SendAttr send_out = {shape_type_pair.first, shape_type_pair.second, depend};
return send_out;
}
@ -324,7 +356,12 @@ void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNode
recv_tag += 1;
auto src_rank = global_rank_ - (user_node_stage - node_stage) * per_stage_rank_num_;
Attr attr_rank = std::make_pair("src_rank", MakeValue(src_rank));
auto op_info_pair = GetOpInfo(node);
std::pair<OperatorInfoPtr, TensorInfoPtr> op_info_pair;
if (node->isa<Parameter>()) {
op_info_pair = GetParameterPair(node);
} else {
op_info_pair = GetOpInfo(node);
}
auto tensor_info = op_info_pair.second;
MS_EXCEPTION_IF_NULL(tensor_info);
auto slice_shape = tensor_info->slice_shape();
@ -333,12 +370,19 @@ void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNode
Attr attr_dtype = std::make_pair("dtype", shape_type_pair.second);
OperatorAttrs attrs = {attr_tag, attr_rank, attr_shape, attr_dtype};
auto recv_op = CreatOpInstance(attrs, RECEIVE, "recv");
std::vector<AnfNodePtr> recv_input = {NewValueNode(recv_op), virtual_param_};
std::vector<AnfNodePtr> recv_input;
if (node->isa<Parameter>()) {
recv_input = {NewValueNode(recv_op), node};
} else {
recv_input = {NewValueNode(recv_op), virtual_param_};
}
auto recv = graph->NewCNode(recv_input);
auto node_abstract = node->abstract();
recv->set_abstract(node_abstract);
recv->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_info->tensor_layout()));
recv->set_user_data<OperatorInfo>(op_info_pair.first);
if (op_info_pair.first != nullptr) {
recv->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_info->tensor_layout()));
recv->set_user_data<OperatorInfo>(op_info_pair.first);
}
manager_->SetEdge(use_node, index, recv);
}
@ -448,13 +492,6 @@ void PipelineTransformer::ElimGraphStage() {
}
}
// Returns true iff `cnode`'s first input is a ValueNode holding a Primitive whose
// name equals `name`.
// Fix: the original dereferenced `prim` unconditionally — when input(0) holds a
// non-Primitive value (e.g. a FuncGraph for a subgraph call) the cast yields
// nullptr and `prim->name()` crashes. A boolean predicate should answer "no"
// for non-primitive calls instead of faulting, matching how call sites use it
// (they `continue` on false).
bool PipelineTransformer::IsSomePrimitive(const CNodePtr &cnode, const std::string &name) {
  MS_EXCEPTION_IF_NULL(cnode);
  auto anf_node = cnode->input(0)->cast<ValueNodePtr>();
  if (anf_node == nullptr) {
    // input(0) is not a value node (e.g. a nested call) — not a primitive.
    return false;
  }
  auto prim = anf_node->value()->cast<PrimitivePtr>();
  if (prim == nullptr) {
    // Value node holds something other than a Primitive.
    return false;
  }
  return prim->name() == name;
}
std::pair<CNodePtr, FuncGraphPtr> PipelineTransformer::FindSensNode() {
std::pair<CNodePtr, FuncGraphPtr> sens_graph_pair;
CNodePtr sens_cnode;
@ -471,7 +508,7 @@ std::pair<CNodePtr, FuncGraphPtr> PipelineTransformer::FindSensNode() {
}
auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast<CNodePtr>();
if (!IsSomePrimitive(expect_tuple_getitem_cnode, TUPLE_GETITEM)) {
if (!IsPrimitiveCNode(expect_tuple_getitem_cnode, prim::kPrimTupleGetItem)) {
continue;
}
auto expect_anonymous = expect_tuple_getitem_cnode->input(1);
@ -484,7 +521,7 @@ std::pair<CNodePtr, FuncGraphPtr> PipelineTransformer::FindSensNode() {
continue;
}
auto expect_j_cnode = expect_j->cast<CNodePtr>();
if (!IsSomePrimitive(expect_j_cnode, J)) {
if (!IsPrimitiveCNode(expect_j_cnode, prim::kPrimJ)) {
continue;
}
func_graph = GetValueNode<FuncGraphPtr>(expect_j_cnode->input(1));

View File

@ -58,7 +58,6 @@ class PipelineTransformer {
private:
std::pair<bool, int> IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users);
bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name);
void DoBroadCast(const FuncGraphPtr &func);
SendAttr InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter, int user_node_stage, int node_stage);
void InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index,
@ -66,6 +65,7 @@ class PipelineTransformer {
void CutBorder(const FuncGraphPtr &graph);
bool IsStageNode(const CNodePtr &node);
std::pair<OperatorInfoPtr, TensorInfoPtr> GetOpInfo(const AnfNodePtr &node);
std::pair<OperatorInfoPtr, TensorInfoPtr> GetParameterPair(const AnfNodePtr &node);
OperatorInfoPtr CreateOpInfo(const CNodePtr &cnode);
bool IsPipelineCareNode(const CNodePtr &cnode);
std::pair<CNodePtr, FuncGraphPtr> FindSensNode();

View File

@ -869,6 +869,9 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap
return FindParameter(cnode->input(index), func_graph);
}
} else {
if (IsSomePrimitive(cnode, RECEIVE) && !cnode->has_user_data<OperatorInfo>()) {
return std::make_pair(node, false);
}
if (IsParallelCareNode(cnode)) {
return std::make_pair(nullptr, false);
} else {

View File

@ -56,9 +56,11 @@ class DatasetLenet():
class MatMulCell(nn.Cell):
def __init__(self, strategy1, strategy2):
def __init__(self, strategy1, strategy2, param=None):
super().__init__()
self.param = Parameter(initializer("zeros", [64, 64]), name="param")
if param is not None:
self.param = param
self.param1 = Parameter(initializer("zeros", [64, 64]), name="param1")
self.matmul = P.MatMul().shard(strategy1)
self.matmul1 = P.MatMul().shard(strategy2)
@ -70,11 +72,11 @@ class MatMulCell(nn.Cell):
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
def __init__(self, strategy1, strategy2, param=None):
super().__init__()
self.block = nn.CellList()
for i in range(2):
cell = MatMulCell(strategy1, strategy2)
cell = MatMulCell(strategy1, strategy2, param)
cell.stage = i
self.block.append(cell)
@ -94,7 +96,33 @@ class PipelineSplit(nn.Cell):
return x
def test_pipeline_split():
class PipelineSplit2(nn.Cell):
    """Pipeline-split network whose sub-cells share one parameter.

    The single ``param`` created here is handed to every stage of ``Net``,
    exercising the shared-parameter send/recv path of the pipeline splitter.
    """

    def __init__(self, strategy1, strategy2):
        super().__init__()
        self.param = Parameter(initializer("zeros", [64, 64]), name="param")
        self.cell = Net(strategy1, strategy2, self.param)

    def construct(self, x, label):
        # label is accepted for the training signature but unused here.
        return self.cell(x)
def test_pipeline_split_stage0():
    """Train the pipeline-split net on stage 0 (rank 0 of 8 devices, 2 stages)."""
    context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    inputs = Tensor(np.ones([32, 64]), dtype=ms.float32)
    targets = Tensor(np.ones([64, 64]), dtype=ms.float32)
    net = PipelineSplit(((4, 1), (1, 1)), ((2, 1), (1, 1)))
    dataset = DatasetLenet(inputs, targets, 3)
    # Stage 0 owns block[0]; only its parameters are optimized on this rank.
    optimizer = nn.Lamb(net.cell.block[0].trainable_params(), learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
def test_pipeline_split_stage1():
context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
data = Tensor(np.ones([32, 64]), dtype=ms.float32)
@ -107,3 +135,33 @@ def test_pipeline_split():
optimizer = nn.Lamb(params, learning_rate=0.01)
model = Model(net, optimizer=optimizer)
model.train(2, dataset, dataset_sink_mode=False)
def test_pipeline_split_shared_parameter_stage0():
    """Shared-parameter pipeline split on stage 0 (rank 0 of 8 devices, 2 stages)."""
    context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    inputs = Tensor(np.ones([32, 64]), dtype=ms.float32)
    targets = Tensor(np.ones([64, 64]), dtype=ms.float32)
    # PipelineSplit2 shares a single parameter across both stages.
    net = PipelineSplit2(((4, 1), (1, 1)), ((2, 1), (1, 1)))
    dataset = DatasetLenet(inputs, targets, 3)
    optimizer = nn.Lamb(net.cell.block[0].trainable_params(), learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
def test_pipeline_split_shared_parameter_stage1():
    """Shared-parameter pipeline split on stage 1 (rank 4 of 8 devices, 2 stages)."""
    context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    inputs = Tensor(np.ones([32, 64]), dtype=ms.float32)
    targets = Tensor(np.ones([64, 64]), dtype=ms.float32)
    # PipelineSplit2 shares a single parameter across both stages.
    net = PipelineSplit2(((4, 1), (1, 1)), ((2, 1), (1, 1)))
    dataset = DatasetLenet(inputs, targets, 3)
    # Stage 1 owns block[1]; only its parameters are optimized on this rank.
    optimizer = nn.Lamb(net.cell.block[1].trainable_params(), learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)