use neighbor-exchange-v2 for conv2d

This commit is contained in:
yangzhenzhang 2021-12-17 14:48:41 +08:00
parent 12fe2ba72e
commit 8a68577756
3 changed files with 108 additions and 216 deletions

View File

@ -405,10 +405,12 @@ Status Conv2DInfo::InferRankBias() {
left_rank_id_ = *(it - 1);
right_rank_id_ = *(it + 1);
}
all_to_all_group_ = g_device_manager->world_group(); // use world group temporarily
MS_LOG(INFO) << name_ << ": The current rank is " << rank << ", the device list of w dimension is " << group_devices
<< ", the rank bias is " << rank_bias_ << ", the left rank bias is " << left_rank_bias_
<< ", the right rank bias is " << right_rank_bias_ << ", the left rank id is " << left_rank_id_
<< ", the right rank id is " << right_rank_id_;
<< ", the right rank id is " << right_rank_id_ << ", the all to all group is " << all_to_all_group_;
return SUCCESS;
}
@ -527,128 +529,73 @@ void Conv2DInfo::InferNewPadList() {
MS_LOG(INFO) << name_ << ": the new pad list is " << new_pad_list_;
}
void Conv2DInfo::InferSendRecvFlag() {
if (rank_bias_ == 0) { // the first rank
left_need_send_ = false;
left_need_recv_ = false;
right_need_send_ = (right_rank_overlap_left_size_ > 0);
right_need_recv_ = (overlap_right_size_ > 0); // no need the right pad
} else if (rank_bias_ == w_dimension_shard_num_ - 1) { // the last rank
left_need_send_ = (left_rank_overlap_right_size_ > 0);
left_need_recv_ = (overlap_left_size_ > 0);
right_need_send_ = false;
right_need_recv_ = false;
} else { // the middle rank
left_need_send_ = (left_rank_overlap_right_size_ > 0);
left_need_recv_ = (overlap_left_size_ > 0);
right_need_send_ = (right_rank_overlap_left_size_ > 0);
right_need_recv_ = (overlap_right_size_ > 0);
}
MS_LOG(INFO) << name_ << ": The left need send is " << left_need_send_ << ", the left need recv is "
<< left_need_recv_ << ", the right need send is " << right_need_send_ << ", the right need recv is "
<< right_need_recv_;
void Conv2DInfo::InferCommunicationAttrs() {
// send rank ids: [-1, -1, send_right_rank, -1, -1, -1, send_left_rank, -1]
// recv rank ids: [-1, -1, recv_right_rank, -1, -1, -1, recv_left_rank, -1]
// send lens: [0, 0, send_left_len, send_right_len]
// recv lens: [0, 0, recv_left_len, recv_right_len]
int64_t send_right_rank = -1, send_left_rank = -1, recv_right_rank = -1, recv_left_rank = -1;
int64_t send_left_len = 0, send_right_len = 0, recv_left_len = 0, recv_right_len = 0;
if (left_need_send_) {
if (left_rank_overlap_right_size_ >= input_slice_shape_[3]) {
MS_LOG(EXCEPTION) << name_ << ": Do not support left overlap size(" << left_rank_overlap_right_size_
<< ") larger than or equal to slice shape in w dimension(" << input_slice_shape_[3] << ")";
}
send_rank_ids_.push_back(left_rank_id_);
if (rank_bias_ == 0) {
// the first rank
send_right_len = right_rank_overlap_left_size_;
send_right_rank = send_right_len > 0 ? right_rank_id_ : -1;
recv_right_len = overlap_right_size_;
recv_right_rank = recv_right_len > 0 ? right_rank_id_ : -1;
} else if (rank_bias_ == w_dimension_shard_num_ - 1) {
// the last rank
send_left_len = left_rank_overlap_right_size_;
send_left_rank = send_left_len > 0 ? left_rank_id_ : -1;
recv_left_len = overlap_left_size_;
recv_left_rank = recv_left_len > 0 ? left_rank_id_ : -1;
} else {
// the middle rank
send_right_len = right_rank_overlap_left_size_;
send_right_rank = send_right_len > 0 ? right_rank_id_ : -1;
recv_right_len = overlap_right_size_;
recv_right_rank = recv_right_len > 0 ? right_rank_id_ : -1;
send_left_len = left_rank_overlap_right_size_;
send_left_rank = send_left_len > 0 ? left_rank_id_ : -1;
recv_left_len = overlap_left_size_;
recv_left_rank = recv_left_len > 0 ? left_rank_id_ : -1;
}
if (right_need_send_) {
if (right_rank_overlap_left_size_ >= input_slice_shape_[3]) {
MS_LOG(EXCEPTION) << name_ << ": Do not support left overlap size(" << right_rank_overlap_left_size_
<< ") larger than or equal to slice shape in w dimension(" << input_slice_shape_[3] << ")";
}
send_rank_ids_.push_back(right_rank_id_);
}
send_rank_ids_ = {-1, -1, send_right_rank, -1, -1, -1, send_left_rank, -1};
recv_rank_ids_ = {-1, -1, recv_right_rank, -1, -1, -1, recv_left_rank, -1};
send_lens_ = {0, 0, send_left_len, send_right_len};
recv_lens_ = {0, 0, recv_left_len, recv_right_len};
MS_LOG(INFO) << name_ << ": The send rank ids is " << send_rank_ids_ << ", the send lens is " << send_lens_
<< ", the recv rank ids is " << recv_rank_ids_ << ", the recv lens is " << recv_lens_;
if (left_need_recv_) {
recv_rank_ids_.push_back(left_rank_id_);
}
if (right_need_recv_) {
recv_rank_ids_.push_back(right_rank_id_);
}
MS_LOG(INFO) << name_ << ": The send rank ids is " << send_rank_ids_ << ", the recv rank ids is " << recv_rank_ids_;
}
void Conv2DInfo::InferOverlapShapes() {
if (left_need_recv_) {
Shape left_recv_shape = input_slice_shape_;
left_recv_shape[3] = overlap_left_size_;
recv_shapes_.push_back(left_recv_shape);
}
if (right_need_recv_) {
Shape right_recv_shape = input_slice_shape_;
right_recv_shape[3] = overlap_right_size_;
recv_shapes_.push_back(right_recv_shape);
}
if (left_need_send_) {
Shape left_send_shape = input_slice_shape_;
left_send_shape[3] = left_rank_overlap_right_size_;
send_shapes_.push_back(left_send_shape);
}
if (right_need_send_) {
Shape right_send_shape = input_slice_shape_;
right_send_shape[3] = right_rank_overlap_left_size_;
send_shapes_.push_back(right_send_shape);
}
MS_LOG(INFO) << name_ << ": the recv shapes is " << recv_shapes_ << ", the send shapes is " << send_shapes_;
}
void Conv2DInfo::InferStridedSliceAttrs() {
if (left_need_send_) {
left_strided_slice_begin_ = {0, 0, 0, 0};
left_strided_slice_end_ = input_slice_shape_;
left_strided_slice_end_[3] = left_rank_overlap_right_size_;
left_strided_slice_strides_ = {1, 1, 1, 1};
MS_LOG(INFO) << name_ << ": The left strided slice begin is " << left_strided_slice_begin_ << ", end is "
<< left_strided_slice_end_;
}
if (right_need_send_) {
right_strided_slice_begin_ = {0, 0, 0, 0};
right_strided_slice_begin_[3] = input_slice_shape_[3] - right_rank_overlap_left_size_;
right_strided_slice_end_ = input_slice_shape_;
right_strided_slice_strides_ = {1, 1, 1, 1};
MS_LOG(INFO) << name_ << ": The right strided slice begin is " << right_strided_slice_begin_ << ", end is "
<< right_strided_slice_end_;
int64_t w_slice_shape = input_slice_shape_[3];
if (send_left_len > w_slice_shape || send_right_len > w_slice_shape || recv_left_len > w_slice_shape ||
recv_right_len > w_slice_shape) {
MS_LOG(EXCEPTION) << name_ << ": The send or recv len larger than slice shape of w dimension " << w_slice_shape;
}
}
void Conv2DInfo::InferNewOperatorAttrs() {
InferNewPadList();
InferSendRecvFlag();
InferOverlapShapes();
InferStridedSliceAttrs();
InferCommunicationAttrs();
}
OperatorAttrs Conv2DInfo::CreateNeighborExchangeAttrs(const CNodePtr &cnode) {
auto type = cnode->Type();
MS_EXCEPTION_IF_NULL(type);
auto tensor_type = type->cast<mindspore::TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto dtype = tensor_type->element();
MS_EXCEPTION_IF_NULL(dtype);
// the type of send_rank_ids, recv_rank_ids, send_shapes, recv_shapes is list, is not tuple, can not use MakeValue
OperatorAttrs Conv2DInfo::CreateNeighborExchangeV2Attrs() {
// the type of send_rank_ids, recv_rank_ids, send_lens, recv_lens is list, is not tuple, can not use MakeValue
// the MakeValue(vector) return a tuple
Attr send_ranks = {SEND_RANK_IDS, MakeListValue(send_rank_ids_)};
Attr recv_ranks = {RECV_RANK_IDS, MakeListValue(recv_rank_ids_)};
Attr send_shapes = {SEND_SHAPES, MakeTupleListValue(send_shapes_)};
Attr recv_shapes = {RECV_SHAPES, MakeTupleListValue(recv_shapes_)};
Attr recv_type = {RECV_TYPE, dtype};
Attr group = {GROUP, MakeValue(g_device_manager->world_group())};
OperatorAttrs attrs = {send_ranks, recv_ranks, recv_shapes, send_shapes, recv_type, group};
Attr send_rank_ids = {SEND_RANK_IDS, MakeListValue(send_rank_ids_)};
Attr send_lens = {SEND_LENS, MakeListValue(send_lens_)};
Attr recv_rank_ids = {RECV_RANK_IDS, MakeListValue(recv_rank_ids_)};
Attr recv_lens = {RECV_LENS, MakeListValue(recv_lens_)};
Attr data_format = {DATA_FORMAT, MakeValue(NCHW)};
Attr group = {GROUP, MakeValue(all_to_all_group_)};
OperatorAttrs attrs = {send_rank_ids, send_lens, recv_rank_ids, recv_lens, data_format, group};
return attrs;
}
@ -716,76 +663,13 @@ void Conv2DInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
MS_LOG(EXCEPTION) << "GenerateGraph Init failed";
}
if (!left_need_send_ && !right_need_send_) {
MS_LOG(EXCEPTION) << name_ << ": Now do not support left no need to send and right no need to send";
}
auto neighbor_exchange_v2_attrs = CreateNeighborExchangeV2Attrs();
auto neighbor_exchange_v2_node =
gen_g_.PushBack({gen_g_.NewOpInst(NEIGHBOREXCHANGEV2, neighbor_exchange_v2_attrs), gen_g_.virtual_input_node()});
if (!left_need_recv_ && !right_need_recv_) {
MS_LOG(EXCEPTION) << name_ << ": Now do not support left no need to recv and right no need to recv";
}
std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes;
std::vector<AnfNodePtr> make_tuple_a_inputs = {NewValueNode(prim::kPrimMakeTuple)};
if (left_need_send_) {
auto slice_left_begin = CreateTuple(left_strided_slice_begin_);
auto slice_left_end = CreateTuple(left_strided_slice_end_);
auto slice_left_strided = CreateTuple(left_strided_slice_strides_);
auto slice_left = gen_g_.PushBack({gen_g_.NewOpInst(STRIDED_SLICE), gen_g_.virtual_input_node(), slice_left_begin,
slice_left_end, slice_left_strided});
make_tuple_a_inputs.push_back(slice_left);
input_nodes.push_back(std::make_pair(slice_left, 1));
}
if (right_need_send_) {
auto slice_right_begin = CreateTuple(right_strided_slice_begin_);
auto slice_right_end = CreateTuple(right_strided_slice_end_);
auto slice_right_strided = CreateTuple(right_strided_slice_strides_);
auto slice_right = gen_g_.PushBack({gen_g_.NewOpInst(STRIDED_SLICE), gen_g_.virtual_input_node(), slice_right_begin,
slice_right_end, slice_right_strided});
make_tuple_a_inputs.push_back(slice_right);
input_nodes.push_back(std::make_pair(slice_right, 1));
}
auto make_tuple_a = graph->NewCNode(make_tuple_a_inputs);
auto alltoall_attrs = CreateNeighborExchangeAttrs(cnode);
auto alltoall_v = gen_g_.PushBack({gen_g_.NewOpInst(NEIGHBOREXCHANGE, alltoall_attrs), make_tuple_a});
AnfNodePtr conv2d;
Attr concat_axis = {AXIS, MakeValue(-1)};
OperatorAttrs concat_attrs = {concat_axis};
if (left_need_recv_) {
std::vector<AnfNodePtr> tuple_getitem_l_inputs = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
CreatInt64Imm(0)};
auto tuple_getitem_l = graph->NewCNode(tuple_getitem_l_inputs);
std::vector<AnfNodePtr> make_tuple_l_inputs = {NewValueNode(prim::kPrimMakeTuple), tuple_getitem_l,
cnode->input(1)};
auto make_tuple_l = graph->NewCNode(make_tuple_l_inputs);
auto concat_l = gen_g_.PushBack({gen_g_.NewOpInst(CONCAT, concat_attrs), make_tuple_l});
if (right_need_recv_) {
std::vector<AnfNodePtr> tuple_getitem_r_inputs = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
CreatInt64Imm(1)};
auto tuple_getitem_r = graph->NewCNode(tuple_getitem_r_inputs);
std::vector<AnfNodePtr> make_tuple_r_inputs = {NewValueNode(prim::kPrimMakeTuple), concat_l, tuple_getitem_r};
auto make_tuple_r = graph->NewCNode(make_tuple_r_inputs);
auto concat_r = gen_g_.PushBack({gen_g_.NewOpInst(CONCAT, concat_attrs), make_tuple_r});
conv2d = GenerateConv2DNode(concat_r, cnode);
} else {
conv2d = GenerateConv2DNode(concat_l, cnode);
}
} else { // left no need recv, and right need recv
std::vector<AnfNodePtr> tuple_getitem_r_inputs_1 = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
CreatInt64Imm(0)};
auto tuple_getitem_r_1 = graph->NewCNode(tuple_getitem_r_inputs_1);
std::vector<AnfNodePtr> make_tuple_r_inputs_1 = {NewValueNode(prim::kPrimMakeTuple), gen_g_.virtual_input_node(),
tuple_getitem_r_1};
auto make_tuple_r_1 = graph->NewCNode(make_tuple_r_inputs_1);
input_nodes.push_back(std::make_pair(make_tuple_r_1, 1));
auto concat_r_1 = gen_g_.PushBack({gen_g_.NewOpInst(CONCAT, concat_attrs), make_tuple_r_1});
conv2d = GenerateConv2DNode(concat_r_1, cnode);
}
auto conv2d = GenerateConv2DNode(neighbor_exchange_v2_node, cnode);
std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(neighbor_exchange_v2_node, 1)};
replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
std::make_pair(input_nodes, conv2d));
}

View File

@ -53,13 +53,11 @@ class Conv2DInfo : public OperatorInfo {
Status InferRankBias();
void InferOverlapSize();
void InferNewOperatorAttrs();
void InferSendRecvFlag();
void InferOverlapShapes();
void InferStridedSliceAttrs();
void InferCommunicationAttrs();
std::string ReplaceNodeName() const;
AnfNodePtr GenerateConv2DNode(const AnfNodePtr &new_input, const CNodePtr &cnode);
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
OperatorAttrs CreateNeighborExchangeAttrs(const CNodePtr &cnode);
OperatorAttrs CreateNeighborExchangeV2Attrs();
OperatorAttrs CreateConv2DAttrs();
void ComputeReplaceGraph(const CNodePtr &cnode);
@ -91,21 +89,11 @@ class Conv2DInfo : public OperatorInfo {
int64_t w_dimension_shard_num_ = 1;
Shape input_slice_shape_;
bool left_need_send_ = false;
bool left_need_recv_ = false;
bool right_need_send_ = false;
bool right_need_recv_ = false;
Shape left_strided_slice_begin_;
Shape left_strided_slice_end_;
Shape left_strided_slice_strides_;
Shape right_strided_slice_begin_;
Shape right_strided_slice_end_;
Shape right_strided_slice_strides_;
std::vector<int64_t> send_rank_ids_;
std::vector<int64_t> recv_rank_ids_;
Shapes send_shapes_;
Shapes recv_shapes_;
std::vector<int64_t> send_lens_;
std::vector<int64_t> recv_lens_;
std::string all_to_all_group_;
GenerateGraph gen_g_ = GenerateGraph(attrs_);

View File

@ -70,6 +70,11 @@ def compile_net(net):
def test_conv2d_transpose_data_parallel():
"""
Feature: test data parallel strategy
Description: only shard batch dimension
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
strategy2 = ((8, 1, 1, 1),)
@ -78,6 +83,11 @@ def test_conv2d_transpose_data_parallel():
def test_conv2d_transpose_model_parallel1():
"""
Feature: test model parallel strategy
Description: only shard batch dimension and channel dimension
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
strategy2 = ((8, 1, 1, 1),)
@ -86,6 +96,11 @@ def test_conv2d_transpose_model_parallel1():
def test_conv2d_transpose_model_parallel2():
"""
Feature: test model parallel strategy
Description: shard batch dimension and w dimension
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 1, 1, 4), (1, 1, 1, 1))
strategy2 = ((2, 1, 1, 4),)
@ -95,6 +110,11 @@ def test_conv2d_transpose_model_parallel2():
def test_conv2d_transpose_model_parallel3():
"""
Feature: test model parallel strategy
Description: shard batch dimension, channel dimension and w dimension
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((2, 2, 1, 4), (2, 1, 1, 1))
strategy2 = ((2, 2, 1, 4),)
@ -104,6 +124,11 @@ def test_conv2d_transpose_model_parallel3():
def test_conv2d_transpose_all_rank_no_need_overlap():
"""
Feature: test model parallel strategy
Description: shard batch dimension, channel dimension and w dimension
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((2, 2, 1, 4), (2, 1, 1, 1))
strategy2 = ((2, 2, 1, 4),)
@ -113,6 +138,11 @@ def test_conv2d_transpose_all_rank_no_need_overlap():
def test_conv2d_transpose_split_h_or_w_in_pad_mode():
"""
Feature: test pad mode
Description: shard batch dimension, channel dimension and w dimension in pad mode
Expectation: compile failed
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((2, 2, 1, 4), (2, 1, 1, 1))
strategy2 = ((2, 2, 1, 4),)
@ -123,6 +153,11 @@ def test_conv2d_transpose_split_h_or_w_in_pad_mode():
def test_conv2d_transpose_split_h_in_same_mode():
"""
Feature: test split h dimension
Description: shard h dimension in same mode
Expectation: compile failed
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((2, 2, 4, 1), (2, 1, 1, 1))
strategy2 = ((2, 2, 1, 4),)
@ -133,6 +168,11 @@ def test_conv2d_transpose_split_h_in_same_mode():
def test_conv2d_transpose_overlap_size_too_large():
"""
Feature: test overlap size is too large
Description: shard w dimension and overlap size larger than slice shape
Expectation: compile failed
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
strategy2 = ((1, 1, 1, 8),)
@ -141,23 +181,3 @@ def test_conv2d_transpose_overlap_size_too_large():
with pytest.raises(RuntimeError):
compile_net(net)
def test_conv2d_transpose_overlap_size_too_large2():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
strategy2 = ((2, 2, 1, 4),)
net = Net2(_w2, out_channel=8, kernel_size=(4, 4), pad_mode="same", stride=2,
strategy1=strategy1, strategy2=strategy2)
with pytest.raises(RuntimeError):
compile_net(net)
def test_conv2d_transpose_rank0_no_need_overlap():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((2, 2, 1, 4), (2, 1, 1, 1))
strategy2 = ((2, 2, 1, 4),)
net = Net2(_w4, out_channel=8, kernel_size=(3, 3), pad_mode="same", stride=2,
strategy1=strategy1, strategy2=strategy2)
with pytest.raises(RuntimeError):
compile_net(net)