From 8a68577756dba4eac27d5c5a1df5814dad9b4c93 Mon Sep 17 00:00:00 2001
From: yangzhenzhang <yangzhenzhang@huawei.com>
Date: Fri, 17 Dec 2021 14:48:41 +0800
Subject: [PATCH] use neighbor-exchange-v2 for conv2d

Replace the StridedSlice + NeighborExchange + Concat pattern that exchanged
w-dimension halo data for a sharded Conv2D with a single NeighborExchangeV2
operator, which takes the per-direction send/recv rank ids and send/recv
lens as attributes.
---
 .../frontend/parallel/ops_info/conv2d_info.cc | 240 +++++-------------
 .../frontend/parallel/ops_info/conv2d_info.h  |  22 +-
 .../python/parallel/test_conv2d_transpose.py  |  62 +++--
 3 files changed, 108 insertions(+), 216 deletions(-)

diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc
index b3bba806704..9efb0f2dc77 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc
@@ -405,10 +405,12 @@ Status Conv2DInfo::InferRankBias() {
     left_rank_id_ = *(it - 1);
     right_rank_id_ = *(it + 1);
   }
+
+  all_to_all_group_ = g_device_manager->world_group();  // use world group temporarily
   MS_LOG(INFO) << name_ << ": The current rank is " << rank << ", the device list of w dimension is " << group_devices
                << ", the rank bias is " << rank_bias_ << ", the left rank bias is " << left_rank_bias_
                << ", the right rank bias is " << right_rank_bias_ << ", the left rank id is " << left_rank_id_
-               << ", the right rank id is " << right_rank_id_;
+               << ", the right rank id is " << right_rank_id_ << ", the all to all group is " << all_to_all_group_;
   return SUCCESS;
 }
 
@@ -527,128 +529,73 @@ void Conv2DInfo::InferNewPadList() {
   MS_LOG(INFO) << name_ << ": the new pad list is " << new_pad_list_;
 }
 
-void Conv2DInfo::InferSendRecvFlag() {
-  if (rank_bias_ == 0) {  // the first rank
-    left_need_send_ = false;
-    left_need_recv_ = false;
-    right_need_send_ = (right_rank_overlap_left_size_ > 0);
-    right_need_recv_ = (overlap_right_size_ > 0);  // no need the right pad
-  } else if (rank_bias_ == w_dimension_shard_num_ - 1) {  // the last rank
-    left_need_send_ = (left_rank_overlap_right_size_ > 0);
-    left_need_recv_ = (overlap_left_size_ > 0);
-    right_need_send_ = false;
-    right_need_recv_ = false;
-  } else {  // the middle rank
-    left_need_send_ = (left_rank_overlap_right_size_ > 0);
-    left_need_recv_ = (overlap_left_size_ > 0);
-    right_need_send_ = (right_rank_overlap_left_size_ > 0);
-    right_need_recv_ = (overlap_right_size_ > 0);
-  }
-  MS_LOG(INFO) << name_ << ": The left need send is " << left_need_send_ << ", the left need recv is "
-               << left_need_recv_ << ", the right need send is " << right_need_send_ << ", the right need recv is "
-               << right_need_recv_;
+void Conv2DInfo::InferCommunicationAttrs() {
+  // send rank ids: [-1, -1, send_right_rank, -1, -1, -1, send_left_rank, -1]
+  // recv rank ids: [-1, -1, recv_right_rank, -1, -1, -1, recv_left_rank, -1]
+  // send lens: [0, 0, send_left_len, send_right_len]
+  // recv lens: [0, 0, recv_left_len, recv_right_len]
+  int64_t send_right_rank = -1, send_left_rank = -1, recv_right_rank = -1, recv_left_rank = -1;
+  int64_t send_left_len = 0, send_right_len = 0, recv_left_len = 0, recv_right_len = 0;
 
-  if (left_need_send_) {
-    if (left_rank_overlap_right_size_ >= input_slice_shape_[3]) {
-      MS_LOG(EXCEPTION) << name_ << ": Do not support left overlap size(" << left_rank_overlap_right_size_
-                        << ") larger than or equal to slice shape in w dimension(" << input_slice_shape_[3] << ")";
-    }
-    send_rank_ids_.push_back(left_rank_id_);
+  if (rank_bias_ == 0) {
+    // the first rank
+    send_right_len = right_rank_overlap_left_size_;
+    send_right_rank = send_right_len > 0 ? right_rank_id_ : -1;
+
+    recv_right_len = overlap_right_size_;
+    recv_right_rank = recv_right_len > 0 ? right_rank_id_ : -1;
+  } else if (rank_bias_ == w_dimension_shard_num_ - 1) {
+    // the last rank
+    send_left_len = left_rank_overlap_right_size_;
+    send_left_rank = send_left_len > 0 ? left_rank_id_ : -1;
+
+    recv_left_len = overlap_left_size_;
+    recv_left_rank = recv_left_len > 0 ? left_rank_id_ : -1;
+  } else {
+    // the middle rank
+    send_right_len = right_rank_overlap_left_size_;
+    send_right_rank = send_right_len > 0 ? right_rank_id_ : -1;
+
+    recv_right_len = overlap_right_size_;
+    recv_right_rank = recv_right_len > 0 ? right_rank_id_ : -1;
+    send_left_len = left_rank_overlap_right_size_;
+    send_left_rank = send_left_len > 0 ? left_rank_id_ : -1;
+
+    recv_left_len = overlap_left_size_;
+    recv_left_rank = recv_left_len > 0 ? left_rank_id_ : -1;
   }
 
-  if (right_need_send_) {
-    if (right_rank_overlap_left_size_ >= input_slice_shape_[3]) {
-      MS_LOG(EXCEPTION) << name_ << ": Do not support left overlap size(" << right_rank_overlap_left_size_
-                        << ") larger than or equal to slice shape in w dimension(" << input_slice_shape_[3] << ")";
-    }
-    send_rank_ids_.push_back(right_rank_id_);
-  }
+  send_rank_ids_ = {-1, -1, send_right_rank, -1, -1, -1, send_left_rank, -1};
+  recv_rank_ids_ = {-1, -1, recv_right_rank, -1, -1, -1, recv_left_rank, -1};
+  send_lens_ = {0, 0, send_left_len, send_right_len};
+  recv_lens_ = {0, 0, recv_left_len, recv_right_len};
+  MS_LOG(INFO) << name_ << ": The send rank ids is " << send_rank_ids_ << ", the send lens is " << send_lens_
+               << ", the recv rank ids is " << recv_rank_ids_ << ", the recv lens is " << recv_lens_;
 
-  if (left_need_recv_) {
-    recv_rank_ids_.push_back(left_rank_id_);
-  }
-
-  if (right_need_recv_) {
-    recv_rank_ids_.push_back(right_rank_id_);
-  }
-
-  MS_LOG(INFO) << name_ << ": The send rank ids is " << send_rank_ids_ << ", the recv rank ids is " << recv_rank_ids_;
-}
-
-void Conv2DInfo::InferOverlapShapes() {
-  if (left_need_recv_) {
-    Shape left_recv_shape = input_slice_shape_;
-    left_recv_shape[3] = overlap_left_size_;
-    recv_shapes_.push_back(left_recv_shape);
-  }
-
-  if (right_need_recv_) {
-    Shape right_recv_shape = input_slice_shape_;
-    right_recv_shape[3] = overlap_right_size_;
-    recv_shapes_.push_back(right_recv_shape);
-  }
-
-  if (left_need_send_) {
-    Shape left_send_shape = input_slice_shape_;
-    left_send_shape[3] = left_rank_overlap_right_size_;
-    send_shapes_.push_back(left_send_shape);
-  }
-
-  if (right_need_send_) {
-    Shape right_send_shape = input_slice_shape_;
-    right_send_shape[3] = right_rank_overlap_left_size_;
-    send_shapes_.push_back(right_send_shape);
-  }
-  MS_LOG(INFO) << name_ << ": the recv shapes is " << recv_shapes_ << ", the send shapes is " << send_shapes_;
-}
-
-void Conv2DInfo::InferStridedSliceAttrs() {
-  if (left_need_send_) {
-    left_strided_slice_begin_ = {0, 0, 0, 0};
-    left_strided_slice_end_ = input_slice_shape_;
-    left_strided_slice_end_[3] = left_rank_overlap_right_size_;
-    left_strided_slice_strides_ = {1, 1, 1, 1};
-    MS_LOG(INFO) << name_ << ": The left strided slice begin is " << left_strided_slice_begin_ << ", end is "
-                 << left_strided_slice_end_;
-  }
-
-  if (right_need_send_) {
-    right_strided_slice_begin_ = {0, 0, 0, 0};
-    right_strided_slice_begin_[3] = input_slice_shape_[3] - right_rank_overlap_left_size_;
-    right_strided_slice_end_ = input_slice_shape_;
-    right_strided_slice_strides_ = {1, 1, 1, 1};
-    MS_LOG(INFO) << name_ << ": The right strided slice begin is " << right_strided_slice_begin_ << ", end is "
is " - << right_strided_slice_end_; + int64_t w_slice_shape = input_slice_shape_[3]; + if (send_left_len > w_slice_shape || send_right_len > w_slice_shape || recv_left_len > w_slice_shape || + recv_right_len > w_slice_shape) { + MS_LOG(EXCEPTION) << name_ << ": The send or recv len larger than slice shape of w dimension " << w_slice_shape; } } void Conv2DInfo::InferNewOperatorAttrs() { InferNewPadList(); - InferSendRecvFlag(); - - InferOverlapShapes(); - - InferStridedSliceAttrs(); + InferCommunicationAttrs(); } -OperatorAttrs Conv2DInfo::CreateNeighborExchangeAttrs(const CNodePtr &cnode) { - auto type = cnode->Type(); - MS_EXCEPTION_IF_NULL(type); - auto tensor_type = type->cast<mindspore::TensorTypePtr>(); - MS_EXCEPTION_IF_NULL(tensor_type); - auto dtype = tensor_type->element(); - MS_EXCEPTION_IF_NULL(dtype); - - // the type of send_rank_ids, recv_rank_ids, send_shapes, recv_shapes is list, is not tuple, can not use MakeValue +OperatorAttrs Conv2DInfo::CreateNeighborExchangeV2Attrs() { + // the type of send_rank_ids, recv_rank_ids, send_lens, recv_lens is list, is not tuple, can not use MakeValue // the MakeValue(vector) return a tuple - Attr send_ranks = {SEND_RANK_IDS, MakeListValue(send_rank_ids_)}; - Attr recv_ranks = {RECV_RANK_IDS, MakeListValue(recv_rank_ids_)}; - Attr send_shapes = {SEND_SHAPES, MakeTupleListValue(send_shapes_)}; - Attr recv_shapes = {RECV_SHAPES, MakeTupleListValue(recv_shapes_)}; - Attr recv_type = {RECV_TYPE, dtype}; - Attr group = {GROUP, MakeValue(g_device_manager->world_group())}; - OperatorAttrs attrs = {send_ranks, recv_ranks, recv_shapes, send_shapes, recv_type, group}; + Attr send_rank_ids = {SEND_RANK_IDS, MakeListValue(send_rank_ids_)}; + Attr send_lens = {SEND_LENS, MakeListValue(send_lens_)}; + Attr recv_rank_ids = {RECV_RANK_IDS, MakeListValue(recv_rank_ids_)}; + Attr recv_lens = {RECV_LENS, MakeListValue(recv_lens_)}; + Attr data_format = {DATA_FORMAT, MakeValue(NCHW)}; + Attr group = {GROUP, MakeValue(all_to_all_group_)}; + + OperatorAttrs attrs = {send_rank_ids, send_lens, recv_rank_ids, recv_lens, data_format, group}; return attrs; } @@ -716,76 +663,13 @@ void Conv2DInfo::ComputeReplaceGraph(const CNodePtr &cnode) { MS_LOG(EXCEPTION) << "GenerateGraph Init failed"; } - if (!left_need_send_ && !right_need_send_) { - MS_LOG(EXCEPTION) << name_ << ": Now do not support left no need to send and right no need to send"; - } + auto neighbor_exchange_v2_attrs = CreateNeighborExchangeV2Attrs(); + auto neighbor_exchange_v2_node = + gen_g_.PushBack({gen_g_.NewOpInst(NEIGHBOREXCHANGEV2, neighbor_exchange_v2_attrs), gen_g_.virtual_input_node()}); - if (!left_need_recv_ && !right_need_recv_) { - MS_LOG(EXCEPTION) << name_ << ": Now do not support left no need to recv and right no need to recv"; - } - - std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes; - std::vector<AnfNodePtr> make_tuple_a_inputs = {NewValueNode(prim::kPrimMakeTuple)}; - if (left_need_send_) { - auto slice_left_begin = CreateTuple(left_strided_slice_begin_); - auto slice_left_end = CreateTuple(left_strided_slice_end_); - auto slice_left_strided = CreateTuple(left_strided_slice_strides_); - auto slice_left = gen_g_.PushBack({gen_g_.NewOpInst(STRIDED_SLICE), gen_g_.virtual_input_node(), slice_left_begin, - slice_left_end, slice_left_strided}); - make_tuple_a_inputs.push_back(slice_left); - input_nodes.push_back(std::make_pair(slice_left, 1)); - } - if (right_need_send_) { - auto slice_right_begin = CreateTuple(right_strided_slice_begin_); - auto slice_right_end = 
-    auto slice_right_strided = CreateTuple(right_strided_slice_strides_);
-    auto slice_right = gen_g_.PushBack({gen_g_.NewOpInst(STRIDED_SLICE), gen_g_.virtual_input_node(),
-                                        slice_right_begin, slice_right_end, slice_right_strided});
-    make_tuple_a_inputs.push_back(slice_right);
-    input_nodes.push_back(std::make_pair(slice_right, 1));
-  }
-
-  auto make_tuple_a = graph->NewCNode(make_tuple_a_inputs);
-  auto alltoall_attrs = CreateNeighborExchangeAttrs(cnode);
-  auto alltoall_v = gen_g_.PushBack({gen_g_.NewOpInst(NEIGHBOREXCHANGE, alltoall_attrs), make_tuple_a});
-
-  AnfNodePtr conv2d;
-  Attr concat_axis = {AXIS, MakeValue(-1)};
-  OperatorAttrs concat_attrs = {concat_axis};
-
-  if (left_need_recv_) {
-    std::vector<AnfNodePtr> tuple_getitem_l_inputs = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
-                                                      CreatInt64Imm(0)};
-    auto tuple_getitem_l = graph->NewCNode(tuple_getitem_l_inputs);
-    std::vector<AnfNodePtr> make_tuple_l_inputs = {NewValueNode(prim::kPrimMakeTuple), tuple_getitem_l,
-                                                   cnode->input(1)};
-    auto make_tuple_l = graph->NewCNode(make_tuple_l_inputs);
-    auto concat_l = gen_g_.PushBack({gen_g_.NewOpInst(CONCAT, concat_attrs), make_tuple_l});
-
-    if (right_need_recv_) {
-      std::vector<AnfNodePtr> tuple_getitem_r_inputs = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
-                                                        CreatInt64Imm(1)};
-      auto tuple_getitem_r = graph->NewCNode(tuple_getitem_r_inputs);
-      std::vector<AnfNodePtr> make_tuple_r_inputs = {NewValueNode(prim::kPrimMakeTuple), concat_l, tuple_getitem_r};
-      auto make_tuple_r = graph->NewCNode(make_tuple_r_inputs);
-      auto concat_r = gen_g_.PushBack({gen_g_.NewOpInst(CONCAT, concat_attrs), make_tuple_r});
-      conv2d = GenerateConv2DNode(concat_r, cnode);
-    } else {
-      conv2d = GenerateConv2DNode(concat_l, cnode);
-    }
-  } else {  // left no need recv, and right need recv
-    std::vector<AnfNodePtr> tuple_getitem_r_inputs_1 = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
-                                                        CreatInt64Imm(0)};
-    auto tuple_getitem_r_1 = graph->NewCNode(tuple_getitem_r_inputs_1);
-    std::vector<AnfNodePtr> make_tuple_r_inputs_1 = {NewValueNode(prim::kPrimMakeTuple), gen_g_.virtual_input_node(),
-                                                     tuple_getitem_r_1};
-    auto make_tuple_r_1 = graph->NewCNode(make_tuple_r_inputs_1);
-    input_nodes.push_back(std::make_pair(make_tuple_r_1, 1));
-
-    auto concat_r_1 = gen_g_.PushBack({gen_g_.NewOpInst(CONCAT, concat_attrs), make_tuple_r_1});
-    conv2d = GenerateConv2DNode(concat_r_1, cnode);
-  }
+  auto conv2d = GenerateConv2DNode(neighbor_exchange_v2_node, cnode);
 
+  std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(neighbor_exchange_v2_node, 1)};
   replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
     std::make_pair(input_nodes, conv2d));
 }
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h
index 920ed2c5598..6267f56adbd 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h
@@ -53,13 +53,11 @@ class Conv2DInfo : public OperatorInfo {
   Status InferRankBias();
   void InferOverlapSize();
   void InferNewOperatorAttrs();
-  void InferSendRecvFlag();
-  void InferOverlapShapes();
-  void InferStridedSliceAttrs();
+  void InferCommunicationAttrs();
   std::string ReplaceNodeName() const;
   AnfNodePtr GenerateConv2DNode(const AnfNodePtr &new_input, const CNodePtr &cnode);
   ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
-  OperatorAttrs CreateNeighborExchangeAttrs(const CNodePtr &cnode);
+  OperatorAttrs CreateNeighborExchangeV2Attrs();
   OperatorAttrs CreateConv2DAttrs();
   void ComputeReplaceGraph(const CNodePtr &cnode);
@@ -91,21 +89,11 @@ class Conv2DInfo : public OperatorInfo {
   int64_t w_dimension_shard_num_ = 1;
   Shape input_slice_shape_;
 
-  bool left_need_send_ = false;
-  bool left_need_recv_ = false;
-  bool right_need_send_ = false;
-  bool right_need_recv_ = false;
-  Shape left_strided_slice_begin_;
-  Shape left_strided_slice_end_;
-  Shape left_strided_slice_strides_;
-  Shape right_strided_slice_begin_;
-  Shape right_strided_slice_end_;
-  Shape right_strided_slice_strides_;
-
   std::vector<int64_t> send_rank_ids_;
   std::vector<int64_t> recv_rank_ids_;
-  Shapes send_shapes_;
-  Shapes recv_shapes_;
+  std::vector<int64_t> send_lens_;
+  std::vector<int64_t> recv_lens_;
+  std::string all_to_all_group_;
 
   GenerateGraph gen_g_ = GenerateGraph(attrs_);
diff --git a/tests/ut/python/parallel/test_conv2d_transpose.py b/tests/ut/python/parallel/test_conv2d_transpose.py
index fe9111a59d7..eeb09449abb 100644
--- a/tests/ut/python/parallel/test_conv2d_transpose.py
+++ b/tests/ut/python/parallel/test_conv2d_transpose.py
@@ -70,6 +70,11 @@ def compile_net(net):
 
 
 def test_conv2d_transpose_data_parallel():
+    """
+    Feature: test data parallel strategy
+    Description: only shard batch dimension
+    Expectation: compile success
+    """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
     strategy2 = ((8, 1, 1, 1),)
@@ -78,6 +83,11 @@ def test_conv2d_transpose_data_parallel():
 
 
 def test_conv2d_transpose_model_parallel1():
+    """
+    Feature: test model parallel strategy
+    Description: only shard batch dimension and channel dimension
+    Expectation: compile success
+    """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
     strategy2 = ((8, 1, 1, 1),)
@@ -86,6 +96,11 @@ def test_conv2d_transpose_model_parallel1():
 
 
 def test_conv2d_transpose_model_parallel2():
+    """
+    Feature: test model parallel strategy
+    Description: shard batch dimension and w dimension
+    Expectation: compile success
+    """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy1 = ((2, 1, 1, 4), (1, 1, 1, 1))
     strategy2 = ((2, 1, 1, 4),)
@@ -95,6 +110,11 @@ def test_conv2d_transpose_model_parallel2():
 
 
 def test_conv2d_transpose_model_parallel3():
+    """
+    Feature: test model parallel strategy
+    Description: shard batch dimension, channel dimension and w dimension
+    Expectation: compile success
+    """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
     strategy1 = ((2, 2, 1, 4), (2, 1, 1, 1))
     strategy2 = ((2, 2, 1, 4),)
@@ -104,6 +124,11 @@ def test_conv2d_transpose_model_parallel3():
 
 
 def test_conv2d_transpose_all_rank_no_need_overlap():
+    """
+    Feature: test model parallel strategy
+    Description: shard batch dimension, channel dimension and w dimension, and no rank needs to exchange overlap
+    Expectation: compile success
+    """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
     strategy1 = ((2, 2, 1, 4), (2, 1, 1, 1))
     strategy2 = ((2, 2, 1, 4),)
@@ -113,6 +138,11 @@ def test_conv2d_transpose_all_rank_no_need_overlap():
 
 
 def test_conv2d_transpose_split_h_or_w_in_pad_mode():
+    """
+    Feature: test pad mode
+    Description: shard batch dimension, channel dimension and w dimension in pad mode
+    Expectation: compile failed
+    """
""" context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) strategy1 = ((2, 2, 1, 4), (2, 1, 1, 1)) strategy2 = ((2, 2, 1, 4),) @@ -123,6 +153,11 @@ def test_conv2d_transpose_split_h_or_w_in_pad_mode(): def test_conv2d_transpose_split_h_in_same_mode(): + """ + Feature: test split h dimension + Description: shard h dimension in same mode + Expectation: compile failed + """ context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) strategy1 = ((2, 2, 4, 1), (2, 1, 1, 1)) strategy2 = ((2, 2, 1, 4),) @@ -133,6 +168,11 @@ def test_conv2d_transpose_split_h_in_same_mode(): def test_conv2d_transpose_overlap_size_too_large(): + """ + Feature: test overlap size is too large + Description: shard w dimension and overlap size larger than slice shape + Expectation: compile failed + """ context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1)) strategy2 = ((1, 1, 1, 8),) @@ -140,24 +180,4 @@ def test_conv2d_transpose_overlap_size_too_large(): strategy1=strategy1, strategy2=strategy2) with pytest.raises(RuntimeError): compile_net(net) - - -def test_conv2d_transpose_overlap_size_too_large2(): - context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) - strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1)) - strategy2 = ((2, 2, 1, 4),) - net = Net2(_w2, out_channel=8, kernel_size=(4, 4), pad_mode="same", stride=2, - strategy1=strategy1, strategy2=strategy2) - with pytest.raises(RuntimeError): - compile_net(net) - - -def test_conv2d_transpose_rank0_no_need_overlap(): - context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) - strategy1 = ((2, 2, 1, 4), (2, 1, 1, 1)) - strategy2 = ((2, 2, 1, 4),) - net = Net2(_w4, out_channel=8, kernel_size=(3, 3), pad_mode="same", stride=2, - strategy1=strategy1, strategy2=strategy2) - with pytest.raises(RuntimeError): - compile_net(net) - \ No newline at end of file + \ No newline at end of file