From 8a68577756dba4eac27d5c5a1df5814dad9b4c93 Mon Sep 17 00:00:00 2001
From: yangzhenzhang <yangzhenzhang@huawei.com>
Date: Fri, 17 Dec 2021 14:48:41 +0800
Subject: [PATCH] use neighbor-exchange-v2 for conv2d

Replace the StridedSlice + NeighborExchange + Concat pattern that exchanged
w-dimension halo data for a sharded Conv2D with a single NeighborExchangeV2
operator, which takes the per-direction send/recv rank ids and send/recv
lens as attributes.
---
 .../frontend/parallel/ops_info/conv2d_info.cc | 240 +++++-------------
 .../frontend/parallel/ops_info/conv2d_info.h  |  22 +-
 .../python/parallel/test_conv2d_transpose.py  |  62 +++--
 3 files changed, 108 insertions(+), 216 deletions(-)

diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc
index b3bba806704..9efb0f2dc77 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc
@@ -405,10 +405,12 @@ Status Conv2DInfo::InferRankBias() {
     left_rank_id_ = *(it - 1);
     right_rank_id_ = *(it + 1);
   }
+
+  all_to_all_group_ = g_device_manager->world_group();  // use world group temporarily
   MS_LOG(INFO) << name_ << ": The current rank is " << rank << ", the device list of w dimension is " << group_devices
                << ", the rank bias is " << rank_bias_ << ", the left rank bias is " << left_rank_bias_
                << ", the right rank bias is " << right_rank_bias_ << ", the left rank id is " << left_rank_id_
-               << ", the right rank id is " << right_rank_id_;
+               << ", the right rank id is " << right_rank_id_ << ", the all to all group is " << all_to_all_group_;
   return SUCCESS;
 }
 
@@ -527,128 +529,73 @@ void Conv2DInfo::InferNewPadList() {
   MS_LOG(INFO) << name_ << ": the new pad list is " << new_pad_list_;
 }
 
-void Conv2DInfo::InferSendRecvFlag() {
-  if (rank_bias_ == 0) {  // the first rank
-    left_need_send_ = false;
-    left_need_recv_ = false;
-    right_need_send_ = (right_rank_overlap_left_size_ > 0);
-    right_need_recv_ = (overlap_right_size_ > 0);  // no need the right pad
-  } else if (rank_bias_ == w_dimension_shard_num_ - 1) {  // the last rank
-    left_need_send_ = (left_rank_overlap_right_size_ > 0);
-    left_need_recv_ = (overlap_left_size_ > 0);
-    right_need_send_ = false;
-    right_need_recv_ = false;
-  } else {  // the middle rank
-    left_need_send_ = (left_rank_overlap_right_size_ > 0);
-    left_need_recv_ = (overlap_left_size_ > 0);
-    right_need_send_ = (right_rank_overlap_left_size_ > 0);
-    right_need_recv_ = (overlap_right_size_ > 0);
-  }
-  MS_LOG(INFO) << name_ << ": The left need send is " << left_need_send_ << ", the left need recv is "
-               << left_need_recv_ << ", the right need send is " << right_need_send_ << ", the right need recv is "
-               << right_need_recv_;
+void Conv2DInfo::InferCommunicationAttrs() {
+  // send rank ids: [-1, -1, send_right_rank, -1, -1, -1, send_left_rank, -1]
+  // recv rank ids: [-1, -1, recv_right_rank, -1, -1, -1, recv_left_rank, -1]
+  // send lens: [0, 0, send_left_len, send_right_len]
+  // recv lens: [0, 0, recv_left_len, recv_right_len]
+  int64_t send_right_rank = -1, send_left_rank = -1, recv_right_rank = -1, recv_left_rank = -1;
+  int64_t send_left_len = 0, send_right_len = 0, recv_left_len = 0, recv_right_len = 0;
 
-  if (left_need_send_) {
-    if (left_rank_overlap_right_size_ >= input_slice_shape_[3]) {
-      MS_LOG(EXCEPTION) << name_ << ": Do not support left overlap size(" << left_rank_overlap_right_size_
-                        << ") larger than or equal to slice shape in w dimension(" << input_slice_shape_[3] << ")";
-    }
-    send_rank_ids_.push_back(left_rank_id_);
+  if (rank_bias_ == 0) {
+    // the first rank
+    send_right_len = right_rank_overlap_left_size_;
+    send_right_rank = send_right_len > 0 ? right_rank_id_ : -1;
+
+    recv_right_len = overlap_right_size_;
+    recv_right_rank = recv_right_len > 0 ? right_rank_id_ : -1;
+  } else if (rank_bias_ == w_dimension_shard_num_ - 1) {
+    // the last rank
+    send_left_len = left_rank_overlap_right_size_;
+    send_left_rank = send_left_len > 0 ? left_rank_id_ : -1;
+
+    recv_left_len = overlap_left_size_;
+    recv_left_rank = recv_left_len > 0 ? left_rank_id_ : -1;
+  } else {
+    // the middle rank
+    send_right_len = right_rank_overlap_left_size_;
+    send_right_rank = send_right_len > 0 ? right_rank_id_ : -1;
+
+    recv_right_len = overlap_right_size_;
+    recv_right_rank = recv_right_len > 0 ? right_rank_id_ : -1;
+    send_left_len = left_rank_overlap_right_size_;
+    send_left_rank = send_left_len > 0 ? left_rank_id_ : -1;
+
+    recv_left_len = overlap_left_size_;
+    recv_left_rank = recv_left_len > 0 ? left_rank_id_ : -1;
   }
 
-  if (right_need_send_) {
-    if (right_rank_overlap_left_size_ >= input_slice_shape_[3]) {
-      MS_LOG(EXCEPTION) << name_ << ": Do not support left overlap size(" << right_rank_overlap_left_size_
-                        << ") larger than or equal to slice shape in w dimension(" << input_slice_shape_[3] << ")";
-    }
-    send_rank_ids_.push_back(right_rank_id_);
-  }
+  send_rank_ids_ = {-1, -1, send_right_rank, -1, -1, -1, send_left_rank, -1};
+  recv_rank_ids_ = {-1, -1, recv_right_rank, -1, -1, -1, recv_left_rank, -1};
+  send_lens_ = {0, 0, send_left_len, send_right_len};
+  recv_lens_ = {0, 0, recv_left_len, recv_right_len};
+  MS_LOG(INFO) << name_ << ": The send rank ids is " << send_rank_ids_ << ", the send lens is " << send_lens_
+               << ", the recv rank ids is " << recv_rank_ids_ << ", the recv lens is " << recv_lens_;
 
-  if (left_need_recv_) {
-    recv_rank_ids_.push_back(left_rank_id_);
-  }
-
-  if (right_need_recv_) {
-    recv_rank_ids_.push_back(right_rank_id_);
-  }
-
-  MS_LOG(INFO) << name_ << ": The send rank ids is " << send_rank_ids_ << ", the recv rank ids is " << recv_rank_ids_;
-}
-
-void Conv2DInfo::InferOverlapShapes() {
-  if (left_need_recv_) {
-    Shape left_recv_shape = input_slice_shape_;
-    left_recv_shape[3] = overlap_left_size_;
-    recv_shapes_.push_back(left_recv_shape);
-  }
-
-  if (right_need_recv_) {
-    Shape right_recv_shape = input_slice_shape_;
-    right_recv_shape[3] = overlap_right_size_;
-    recv_shapes_.push_back(right_recv_shape);
-  }
-
-  if (left_need_send_) {
-    Shape left_send_shape = input_slice_shape_;
-    left_send_shape[3] = left_rank_overlap_right_size_;
-    send_shapes_.push_back(left_send_shape);
-  }
-
-  if (right_need_send_) {
-    Shape right_send_shape = input_slice_shape_;
-    right_send_shape[3] = right_rank_overlap_left_size_;
-    send_shapes_.push_back(right_send_shape);
-  }
-  MS_LOG(INFO) << name_ << ": the recv shapes is " << recv_shapes_ << ", the send shapes is " << send_shapes_;
-}
-
-void Conv2DInfo::InferStridedSliceAttrs() {
-  if (left_need_send_) {
-    left_strided_slice_begin_ = {0, 0, 0, 0};
-    left_strided_slice_end_ = input_slice_shape_;
-    left_strided_slice_end_[3] = left_rank_overlap_right_size_;
-    left_strided_slice_strides_ = {1, 1, 1, 1};
-    MS_LOG(INFO) << name_ << ": The left strided slice begin is " << left_strided_slice_begin_ << ", end is "
-                 << left_strided_slice_end_;
-  }
-
-  if (right_need_send_) {
-    right_strided_slice_begin_ = {0, 0, 0, 0};
-    right_strided_slice_begin_[3] = input_slice_shape_[3] - right_rank_overlap_left_size_;
-    right_strided_slice_end_ = input_slice_shape_;
-    right_strided_slice_strides_ = {1, 1, 1, 1};
-    MS_LOG(INFO) << name_ << ": The right strided slice begin is " << right_strided_slice_begin_ << ", end is "
is " - << right_strided_slice_end_; + int64_t w_slice_shape = input_slice_shape_[3]; + if (send_left_len > w_slice_shape || send_right_len > w_slice_shape || recv_left_len > w_slice_shape || + recv_right_len > w_slice_shape) { + MS_LOG(EXCEPTION) << name_ << ": The send or recv len larger than slice shape of w dimension " << w_slice_shape; } } void Conv2DInfo::InferNewOperatorAttrs() { InferNewPadList(); - InferSendRecvFlag(); - - InferOverlapShapes(); - - InferStridedSliceAttrs(); + InferCommunicationAttrs(); } -OperatorAttrs Conv2DInfo::CreateNeighborExchangeAttrs(const CNodePtr &cnode) { - auto type = cnode->Type(); - MS_EXCEPTION_IF_NULL(type); - auto tensor_type = type->cast<mindspore::TensorTypePtr>(); - MS_EXCEPTION_IF_NULL(tensor_type); - auto dtype = tensor_type->element(); - MS_EXCEPTION_IF_NULL(dtype); - - // the type of send_rank_ids, recv_rank_ids, send_shapes, recv_shapes is list, is not tuple, can not use MakeValue +OperatorAttrs Conv2DInfo::CreateNeighborExchangeV2Attrs() { + // the type of send_rank_ids, recv_rank_ids, send_lens, recv_lens is list, is not tuple, can not use MakeValue // the MakeValue(vector) return a tuple - Attr send_ranks = {SEND_RANK_IDS, MakeListValue(send_rank_ids_)}; - Attr recv_ranks = {RECV_RANK_IDS, MakeListValue(recv_rank_ids_)}; - Attr send_shapes = {SEND_SHAPES, MakeTupleListValue(send_shapes_)}; - Attr recv_shapes = {RECV_SHAPES, MakeTupleListValue(recv_shapes_)}; - Attr recv_type = {RECV_TYPE, dtype}; - Attr group = {GROUP, MakeValue(g_device_manager->world_group())}; - OperatorAttrs attrs = {send_ranks, recv_ranks, recv_shapes, send_shapes, recv_type, group}; + Attr send_rank_ids = {SEND_RANK_IDS, MakeListValue(send_rank_ids_)}; + Attr send_lens = {SEND_LENS, MakeListValue(send_lens_)}; + Attr recv_rank_ids = {RECV_RANK_IDS, MakeListValue(recv_rank_ids_)}; + Attr recv_lens = {RECV_LENS, MakeListValue(recv_lens_)}; + Attr data_format = {DATA_FORMAT, MakeValue(NCHW)}; + Attr group = {GROUP, MakeValue(all_to_all_group_)}; + + OperatorAttrs attrs = {send_rank_ids, send_lens, recv_rank_ids, recv_lens, data_format, group}; return attrs; } @@ -716,76 +663,13 @@ void Conv2DInfo::ComputeReplaceGraph(const CNodePtr &cnode) { MS_LOG(EXCEPTION) << "GenerateGraph Init failed"; } - if (!left_need_send_ && !right_need_send_) { - MS_LOG(EXCEPTION) << name_ << ": Now do not support left no need to send and right no need to send"; - } + auto neighbor_exchange_v2_attrs = CreateNeighborExchangeV2Attrs(); + auto neighbor_exchange_v2_node = + gen_g_.PushBack({gen_g_.NewOpInst(NEIGHBOREXCHANGEV2, neighbor_exchange_v2_attrs), gen_g_.virtual_input_node()}); - if (!left_need_recv_ && !right_need_recv_) { - MS_LOG(EXCEPTION) << name_ << ": Now do not support left no need to recv and right no need to recv"; - } - - std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes; - std::vector<AnfNodePtr> make_tuple_a_inputs = {NewValueNode(prim::kPrimMakeTuple)}; - if (left_need_send_) { - auto slice_left_begin = CreateTuple(left_strided_slice_begin_); - auto slice_left_end = CreateTuple(left_strided_slice_end_); - auto slice_left_strided = CreateTuple(left_strided_slice_strides_); - auto slice_left = gen_g_.PushBack({gen_g_.NewOpInst(STRIDED_SLICE), gen_g_.virtual_input_node(), slice_left_begin, - slice_left_end, slice_left_strided}); - make_tuple_a_inputs.push_back(slice_left); - input_nodes.push_back(std::make_pair(slice_left, 1)); - } - if (right_need_send_) { - auto slice_right_begin = CreateTuple(right_strided_slice_begin_); - auto slice_right_end = 
-    auto slice_right_strided = CreateTuple(right_strided_slice_strides_);
-    auto slice_right = gen_g_.PushBack({gen_g_.NewOpInst(STRIDED_SLICE), gen_g_.virtual_input_node(),
-                                        slice_right_begin, slice_right_end, slice_right_strided});
-    make_tuple_a_inputs.push_back(slice_right);
-    input_nodes.push_back(std::make_pair(slice_right, 1));
-  }
-
-  auto make_tuple_a = graph->NewCNode(make_tuple_a_inputs);
-  auto alltoall_attrs = CreateNeighborExchangeAttrs(cnode);
-  auto alltoall_v = gen_g_.PushBack({gen_g_.NewOpInst(NEIGHBOREXCHANGE, alltoall_attrs), make_tuple_a});
-
-  AnfNodePtr conv2d;
-  Attr concat_axis = {AXIS, MakeValue(-1)};
-  OperatorAttrs concat_attrs = {concat_axis};
-
-  if (left_need_recv_) {
-    std::vector<AnfNodePtr> tuple_getitem_l_inputs = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
-                                                      CreatInt64Imm(0)};
-    auto tuple_getitem_l = graph->NewCNode(tuple_getitem_l_inputs);
-    std::vector<AnfNodePtr> make_tuple_l_inputs = {NewValueNode(prim::kPrimMakeTuple), tuple_getitem_l,
-                                                   cnode->input(1)};
-    auto make_tuple_l = graph->NewCNode(make_tuple_l_inputs);
-    auto concat_l = gen_g_.PushBack({gen_g_.NewOpInst(CONCAT, concat_attrs), make_tuple_l});
-
-    if (right_need_recv_) {
-      std::vector<AnfNodePtr> tuple_getitem_r_inputs = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
-                                                        CreatInt64Imm(1)};
-      auto tuple_getitem_r = graph->NewCNode(tuple_getitem_r_inputs);
-      std::vector<AnfNodePtr> make_tuple_r_inputs = {NewValueNode(prim::kPrimMakeTuple), concat_l, tuple_getitem_r};
-      auto make_tuple_r = graph->NewCNode(make_tuple_r_inputs);
-      auto concat_r = gen_g_.PushBack({gen_g_.NewOpInst(CONCAT, concat_attrs), make_tuple_r});
-      conv2d = GenerateConv2DNode(concat_r, cnode);
-    } else {
-      conv2d = GenerateConv2DNode(concat_l, cnode);
-    }
-  } else {  // left no need recv, and right need recv
-    std::vector<AnfNodePtr> tuple_getitem_r_inputs_1 = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
-                                                        CreatInt64Imm(0)};
-    auto tuple_getitem_r_1 = graph->NewCNode(tuple_getitem_r_inputs_1);
-    std::vector<AnfNodePtr> make_tuple_r_inputs_1 = {NewValueNode(prim::kPrimMakeTuple), gen_g_.virtual_input_node(),
-                                                     tuple_getitem_r_1};
-    auto make_tuple_r_1 = graph->NewCNode(make_tuple_r_inputs_1);
-    input_nodes.push_back(std::make_pair(make_tuple_r_1, 1));
-
-    auto concat_r_1 = gen_g_.PushBack({gen_g_.NewOpInst(CONCAT, concat_attrs), make_tuple_r_1});
-    conv2d = GenerateConv2DNode(concat_r_1, cnode);
-  }
+  auto conv2d = GenerateConv2DNode(neighbor_exchange_v2_node, cnode);
 
+  std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(neighbor_exchange_v2_node, 1)};
   replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
     std::make_pair(input_nodes, conv2d));
 }
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h
index 920ed2c5598..6267f56adbd 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h
@@ -53,13 +53,11 @@ class Conv2DInfo : public OperatorInfo {
   Status InferRankBias();
   void InferOverlapSize();
   void InferNewOperatorAttrs();
-  void InferSendRecvFlag();
-  void InferOverlapShapes();
-  void InferStridedSliceAttrs();
+  void InferCommunicationAttrs();
   std::string ReplaceNodeName() const;
   AnfNodePtr GenerateConv2DNode(const AnfNodePtr &new_input, const CNodePtr &cnode);
   ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
-  OperatorAttrs CreateNeighborExchangeAttrs(const CNodePtr &cnode);
+  OperatorAttrs CreateNeighborExchangeV2Attrs();
   OperatorAttrs CreateConv2DAttrs();
   void ComputeReplaceGraph(const CNodePtr &cnode);
@@ -91,21 +89,11 @@ class Conv2DInfo : public OperatorInfo {
   int64_t w_dimension_shard_num_ = 1;
   Shape input_slice_shape_;
 
-  bool left_need_send_ = false;
-  bool left_need_recv_ = false;
-  bool right_need_send_ = false;
-  bool right_need_recv_ = false;
-  Shape left_strided_slice_begin_;
-  Shape left_strided_slice_end_;
-  Shape left_strided_slice_strides_;
-  Shape right_strided_slice_begin_;
-  Shape right_strided_slice_end_;
-  Shape right_strided_slice_strides_;
-
   std::vector<int64_t> send_rank_ids_;
   std::vector<int64_t> recv_rank_ids_;
-  Shapes send_shapes_;
-  Shapes recv_shapes_;
+  std::vector<int64_t> send_lens_;
+  std::vector<int64_t> recv_lens_;
+  std::string all_to_all_group_;
 
   GenerateGraph gen_g_ = GenerateGraph(attrs_);
diff --git a/tests/ut/python/parallel/test_conv2d_transpose.py b/tests/ut/python/parallel/test_conv2d_transpose.py
index fe9111a59d7..eeb09449abb 100644
--- a/tests/ut/python/parallel/test_conv2d_transpose.py
+++ b/tests/ut/python/parallel/test_conv2d_transpose.py
@@ -70,6 +70,11 @@ def compile_net(net):
 
 
 def test_conv2d_transpose_data_parallel():
+    """
+    Feature: test data parallel strategy
+    Description: only shard batch dimension
+    Expectation: compile success
+    """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
     strategy2 = ((8, 1, 1, 1),)
@@ -78,6 +83,11 @@ def test_conv2d_transpose_data_parallel():
 
 
 def test_conv2d_transpose_model_parallel1():
+    """
+    Feature: test model parallel strategy
+    Description: only shard batch dimension and channel dimension
+    Expectation: compile success
+    """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
     strategy2 = ((8, 1, 1, 1),)
@@ -86,6 +96,11 @@ def test_conv2d_transpose_model_parallel1():
 
 
 def test_conv2d_transpose_model_parallel2():
+    """
+    Feature: test model parallel strategy
+    Description: shard batch dimension and w dimension
+    Expectation: compile success
+    """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy1 = ((2, 1, 1, 4), (1, 1, 1, 1))
     strategy2 = ((2, 1, 1, 4),)
@@ -95,6 +110,11 @@ def test_conv2d_transpose_model_parallel2():
 
 
 def test_conv2d_transpose_model_parallel3():
+    """
+    Feature: test model parallel strategy
+    Description: shard batch dimension, channel dimension and w dimension
+    Expectation: compile success
+    """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
     strategy1 = ((2, 2, 1, 4), (2, 1, 1, 1))
     strategy2 = ((2, 2, 1, 4),)
@@ -104,6 +124,11 @@ def test_conv2d_transpose_model_parallel3():
 
 
 def test_conv2d_transpose_all_rank_no_need_overlap():
+    """
+    Feature: test model parallel strategy
+    Description: shard batch dimension, channel dimension and w dimension, and no rank needs to exchange overlap
+    Expectation: compile success
+    """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
     strategy1 = ((2, 2, 1, 4), (2, 1, 1, 1))
     strategy2 = ((2, 2, 1, 4),)
@@ -113,6 +138,11 @@ def test_conv2d_transpose_all_rank_no_need_overlap():
 
 
 def test_conv2d_transpose_split_h_or_w_in_pad_mode():
+    """
+    Feature: test pad mode
+    Description: shard batch dimension, channel dimension and w dimension in pad mode
+    Expectation: compile failed
+    """
""" context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) strategy1 = ((2, 2, 1, 4), (2, 1, 1, 1)) strategy2 = ((2, 2, 1, 4),) @@ -123,6 +153,11 @@ def test_conv2d_transpose_split_h_or_w_in_pad_mode(): def test_conv2d_transpose_split_h_in_same_mode(): + """ + Feature: test split h dimension + Description: shard h dimension in same mode + Expectation: compile failed + """ context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) strategy1 = ((2, 2, 4, 1), (2, 1, 1, 1)) strategy2 = ((2, 2, 1, 4),) @@ -133,6 +168,11 @@ def test_conv2d_transpose_split_h_in_same_mode(): def test_conv2d_transpose_overlap_size_too_large(): + """ + Feature: test overlap size is too large + Description: shard w dimension and overlap size larger than slice shape + Expectation: compile failed + """ context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1)) strategy2 = ((1, 1, 1, 8),) @@ -140,24 +180,4 @@ def test_conv2d_transpose_overlap_size_too_large(): strategy1=strategy1, strategy2=strategy2) with pytest.raises(RuntimeError): compile_net(net) - - -def test_conv2d_transpose_overlap_size_too_large2(): - context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) - strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1)) - strategy2 = ((2, 2, 1, 4),) - net = Net2(_w2, out_channel=8, kernel_size=(4, 4), pad_mode="same", stride=2, - strategy1=strategy1, strategy2=strategy2) - with pytest.raises(RuntimeError): - compile_net(net) - - -def test_conv2d_transpose_rank0_no_need_overlap(): - context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) - strategy1 = ((2, 2, 1, 4), (2, 1, 1, 1)) - strategy2 = ((2, 2, 1, 4),) - net = Net2(_w4, out_channel=8, kernel_size=(3, 3), pad_mode="same", stride=2, - strategy1=strategy1, strategy2=strategy2) - with pytest.raises(RuntimeError): - compile_net(net) - \ No newline at end of file + \ No newline at end of file