forked from mindspore-Ecosystem/mindspore
use neighbor-exchange-v2 for conv2d
This commit is contained in:
parent
12fe2ba72e
commit
8a68577756
|
@ -405,10 +405,12 @@ Status Conv2DInfo::InferRankBias() {
|
|||
left_rank_id_ = *(it - 1);
|
||||
right_rank_id_ = *(it + 1);
|
||||
}
|
||||
|
||||
all_to_all_group_ = g_device_manager->world_group(); // use world group temporarily
|
||||
MS_LOG(INFO) << name_ << ": The current rank is " << rank << ", the device list of w dimension is " << group_devices
|
||||
<< ", the rank bias is " << rank_bias_ << ", the left rank bias is " << left_rank_bias_
|
||||
<< ", the right rank bias is " << right_rank_bias_ << ", the left rank id is " << left_rank_id_
|
||||
<< ", the right rank id is " << right_rank_id_;
|
||||
<< ", the right rank id is " << right_rank_id_ << ", the all to all group is " << all_to_all_group_;
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
|
@ -527,128 +529,73 @@ void Conv2DInfo::InferNewPadList() {
|
|||
MS_LOG(INFO) << name_ << ": the new pad list is " << new_pad_list_;
|
||||
}
|
||||
|
||||
void Conv2DInfo::InferSendRecvFlag() {
|
||||
if (rank_bias_ == 0) { // the first rank
|
||||
left_need_send_ = false;
|
||||
left_need_recv_ = false;
|
||||
right_need_send_ = (right_rank_overlap_left_size_ > 0);
|
||||
right_need_recv_ = (overlap_right_size_ > 0); // no need the right pad
|
||||
} else if (rank_bias_ == w_dimension_shard_num_ - 1) { // the last rank
|
||||
left_need_send_ = (left_rank_overlap_right_size_ > 0);
|
||||
left_need_recv_ = (overlap_left_size_ > 0);
|
||||
right_need_send_ = false;
|
||||
right_need_recv_ = false;
|
||||
} else { // the middle rank
|
||||
left_need_send_ = (left_rank_overlap_right_size_ > 0);
|
||||
left_need_recv_ = (overlap_left_size_ > 0);
|
||||
right_need_send_ = (right_rank_overlap_left_size_ > 0);
|
||||
right_need_recv_ = (overlap_right_size_ > 0);
|
||||
}
|
||||
MS_LOG(INFO) << name_ << ": The left need send is " << left_need_send_ << ", the left need recv is "
|
||||
<< left_need_recv_ << ", the right need send is " << right_need_send_ << ", the right need recv is "
|
||||
<< right_need_recv_;
|
||||
void Conv2DInfo::InferCommunicationAttrs() {
|
||||
// send rank ids: [-1, -1, send_right_rank, -1, -1, -1, send_left_rank, -1]
|
||||
// recv rank ids: [-1, -1, recv_right_rank, -1, -1, -1, recv_left_rank, -1]
|
||||
// send lens: [0, 0, send_left_len, send_right_len]
|
||||
// recv lens: [0, 0, recv_left_len, recv_right_len]
|
||||
int64_t send_right_rank = -1, send_left_rank = -1, recv_right_rank = -1, recv_left_rank = -1;
|
||||
int64_t send_left_len = 0, send_right_len = 0, recv_left_len = 0, recv_right_len = 0;
|
||||
|
||||
if (left_need_send_) {
|
||||
if (left_rank_overlap_right_size_ >= input_slice_shape_[3]) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": Do not support left overlap size(" << left_rank_overlap_right_size_
|
||||
<< ") larger than or equal to slice shape in w dimension(" << input_slice_shape_[3] << ")";
|
||||
}
|
||||
send_rank_ids_.push_back(left_rank_id_);
|
||||
if (rank_bias_ == 0) {
|
||||
// the first rank
|
||||
send_right_len = right_rank_overlap_left_size_;
|
||||
send_right_rank = send_right_len > 0 ? right_rank_id_ : -1;
|
||||
|
||||
recv_right_len = overlap_right_size_;
|
||||
recv_right_rank = recv_right_len > 0 ? right_rank_id_ : -1;
|
||||
} else if (rank_bias_ == w_dimension_shard_num_ - 1) {
|
||||
// the last rank
|
||||
send_left_len = left_rank_overlap_right_size_;
|
||||
send_left_rank = send_left_len > 0 ? left_rank_id_ : -1;
|
||||
|
||||
recv_left_len = overlap_left_size_;
|
||||
recv_left_rank = recv_left_len > 0 ? left_rank_id_ : -1;
|
||||
} else {
|
||||
// the middle rank
|
||||
send_right_len = right_rank_overlap_left_size_;
|
||||
send_right_rank = send_right_len > 0 ? right_rank_id_ : -1;
|
||||
|
||||
recv_right_len = overlap_right_size_;
|
||||
recv_right_rank = recv_right_len > 0 ? right_rank_id_ : -1;
|
||||
send_left_len = left_rank_overlap_right_size_;
|
||||
send_left_rank = send_left_len > 0 ? left_rank_id_ : -1;
|
||||
|
||||
recv_left_len = overlap_left_size_;
|
||||
recv_left_rank = recv_left_len > 0 ? left_rank_id_ : -1;
|
||||
}
|
||||
|
||||
if (right_need_send_) {
|
||||
if (right_rank_overlap_left_size_ >= input_slice_shape_[3]) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": Do not support left overlap size(" << right_rank_overlap_left_size_
|
||||
<< ") larger than or equal to slice shape in w dimension(" << input_slice_shape_[3] << ")";
|
||||
}
|
||||
send_rank_ids_.push_back(right_rank_id_);
|
||||
}
|
||||
send_rank_ids_ = {-1, -1, send_right_rank, -1, -1, -1, send_left_rank, -1};
|
||||
recv_rank_ids_ = {-1, -1, recv_right_rank, -1, -1, -1, recv_left_rank, -1};
|
||||
send_lens_ = {0, 0, send_left_len, send_right_len};
|
||||
recv_lens_ = {0, 0, recv_left_len, recv_right_len};
|
||||
MS_LOG(INFO) << name_ << ": The send rank ids is " << send_rank_ids_ << ", the send lens is " << send_lens_
|
||||
<< ", the recv rank ids is " << recv_rank_ids_ << ", the recv lens is " << recv_lens_;
|
||||
|
||||
if (left_need_recv_) {
|
||||
recv_rank_ids_.push_back(left_rank_id_);
|
||||
}
|
||||
|
||||
if (right_need_recv_) {
|
||||
recv_rank_ids_.push_back(right_rank_id_);
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << name_ << ": The send rank ids is " << send_rank_ids_ << ", the recv rank ids is " << recv_rank_ids_;
|
||||
}
|
||||
|
||||
void Conv2DInfo::InferOverlapShapes() {
|
||||
if (left_need_recv_) {
|
||||
Shape left_recv_shape = input_slice_shape_;
|
||||
left_recv_shape[3] = overlap_left_size_;
|
||||
recv_shapes_.push_back(left_recv_shape);
|
||||
}
|
||||
|
||||
if (right_need_recv_) {
|
||||
Shape right_recv_shape = input_slice_shape_;
|
||||
right_recv_shape[3] = overlap_right_size_;
|
||||
recv_shapes_.push_back(right_recv_shape);
|
||||
}
|
||||
|
||||
if (left_need_send_) {
|
||||
Shape left_send_shape = input_slice_shape_;
|
||||
left_send_shape[3] = left_rank_overlap_right_size_;
|
||||
send_shapes_.push_back(left_send_shape);
|
||||
}
|
||||
|
||||
if (right_need_send_) {
|
||||
Shape right_send_shape = input_slice_shape_;
|
||||
right_send_shape[3] = right_rank_overlap_left_size_;
|
||||
send_shapes_.push_back(right_send_shape);
|
||||
}
|
||||
MS_LOG(INFO) << name_ << ": the recv shapes is " << recv_shapes_ << ", the send shapes is " << send_shapes_;
|
||||
}
|
||||
|
||||
void Conv2DInfo::InferStridedSliceAttrs() {
|
||||
if (left_need_send_) {
|
||||
left_strided_slice_begin_ = {0, 0, 0, 0};
|
||||
left_strided_slice_end_ = input_slice_shape_;
|
||||
left_strided_slice_end_[3] = left_rank_overlap_right_size_;
|
||||
left_strided_slice_strides_ = {1, 1, 1, 1};
|
||||
MS_LOG(INFO) << name_ << ": The left strided slice begin is " << left_strided_slice_begin_ << ", end is "
|
||||
<< left_strided_slice_end_;
|
||||
}
|
||||
|
||||
if (right_need_send_) {
|
||||
right_strided_slice_begin_ = {0, 0, 0, 0};
|
||||
right_strided_slice_begin_[3] = input_slice_shape_[3] - right_rank_overlap_left_size_;
|
||||
right_strided_slice_end_ = input_slice_shape_;
|
||||
right_strided_slice_strides_ = {1, 1, 1, 1};
|
||||
MS_LOG(INFO) << name_ << ": The right strided slice begin is " << right_strided_slice_begin_ << ", end is "
|
||||
<< right_strided_slice_end_;
|
||||
int64_t w_slice_shape = input_slice_shape_[3];
|
||||
if (send_left_len > w_slice_shape || send_right_len > w_slice_shape || recv_left_len > w_slice_shape ||
|
||||
recv_right_len > w_slice_shape) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": The send or recv len larger than slice shape of w dimension " << w_slice_shape;
|
||||
}
|
||||
}
|
||||
|
||||
void Conv2DInfo::InferNewOperatorAttrs() {
|
||||
InferNewPadList();
|
||||
|
||||
InferSendRecvFlag();
|
||||
|
||||
InferOverlapShapes();
|
||||
|
||||
InferStridedSliceAttrs();
|
||||
InferCommunicationAttrs();
|
||||
}
|
||||
|
||||
OperatorAttrs Conv2DInfo::CreateNeighborExchangeAttrs(const CNodePtr &cnode) {
|
||||
auto type = cnode->Type();
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
auto tensor_type = type->cast<mindspore::TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||
auto dtype = tensor_type->element();
|
||||
MS_EXCEPTION_IF_NULL(dtype);
|
||||
|
||||
// the type of send_rank_ids, recv_rank_ids, send_shapes, recv_shapes is list, is not tuple, can not use MakeValue
|
||||
OperatorAttrs Conv2DInfo::CreateNeighborExchangeV2Attrs() {
|
||||
// the type of send_rank_ids, recv_rank_ids, send_lens, recv_lens is list, is not tuple, can not use MakeValue
|
||||
// the MakeValue(vector) return a tuple
|
||||
Attr send_ranks = {SEND_RANK_IDS, MakeListValue(send_rank_ids_)};
|
||||
Attr recv_ranks = {RECV_RANK_IDS, MakeListValue(recv_rank_ids_)};
|
||||
Attr send_shapes = {SEND_SHAPES, MakeTupleListValue(send_shapes_)};
|
||||
Attr recv_shapes = {RECV_SHAPES, MakeTupleListValue(recv_shapes_)};
|
||||
Attr recv_type = {RECV_TYPE, dtype};
|
||||
Attr group = {GROUP, MakeValue(g_device_manager->world_group())};
|
||||
OperatorAttrs attrs = {send_ranks, recv_ranks, recv_shapes, send_shapes, recv_type, group};
|
||||
Attr send_rank_ids = {SEND_RANK_IDS, MakeListValue(send_rank_ids_)};
|
||||
Attr send_lens = {SEND_LENS, MakeListValue(send_lens_)};
|
||||
Attr recv_rank_ids = {RECV_RANK_IDS, MakeListValue(recv_rank_ids_)};
|
||||
Attr recv_lens = {RECV_LENS, MakeListValue(recv_lens_)};
|
||||
Attr data_format = {DATA_FORMAT, MakeValue(NCHW)};
|
||||
Attr group = {GROUP, MakeValue(all_to_all_group_)};
|
||||
|
||||
OperatorAttrs attrs = {send_rank_ids, send_lens, recv_rank_ids, recv_lens, data_format, group};
|
||||
return attrs;
|
||||
}
|
||||
|
||||
|
@ -716,76 +663,13 @@ void Conv2DInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
|||
MS_LOG(EXCEPTION) << "GenerateGraph Init failed";
|
||||
}
|
||||
|
||||
if (!left_need_send_ && !right_need_send_) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": Now do not support left no need to send and right no need to send";
|
||||
}
|
||||
auto neighbor_exchange_v2_attrs = CreateNeighborExchangeV2Attrs();
|
||||
auto neighbor_exchange_v2_node =
|
||||
gen_g_.PushBack({gen_g_.NewOpInst(NEIGHBOREXCHANGEV2, neighbor_exchange_v2_attrs), gen_g_.virtual_input_node()});
|
||||
|
||||
if (!left_need_recv_ && !right_need_recv_) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": Now do not support left no need to recv and right no need to recv";
|
||||
}
|
||||
|
||||
std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes;
|
||||
std::vector<AnfNodePtr> make_tuple_a_inputs = {NewValueNode(prim::kPrimMakeTuple)};
|
||||
if (left_need_send_) {
|
||||
auto slice_left_begin = CreateTuple(left_strided_slice_begin_);
|
||||
auto slice_left_end = CreateTuple(left_strided_slice_end_);
|
||||
auto slice_left_strided = CreateTuple(left_strided_slice_strides_);
|
||||
auto slice_left = gen_g_.PushBack({gen_g_.NewOpInst(STRIDED_SLICE), gen_g_.virtual_input_node(), slice_left_begin,
|
||||
slice_left_end, slice_left_strided});
|
||||
make_tuple_a_inputs.push_back(slice_left);
|
||||
input_nodes.push_back(std::make_pair(slice_left, 1));
|
||||
}
|
||||
if (right_need_send_) {
|
||||
auto slice_right_begin = CreateTuple(right_strided_slice_begin_);
|
||||
auto slice_right_end = CreateTuple(right_strided_slice_end_);
|
||||
auto slice_right_strided = CreateTuple(right_strided_slice_strides_);
|
||||
auto slice_right = gen_g_.PushBack({gen_g_.NewOpInst(STRIDED_SLICE), gen_g_.virtual_input_node(), slice_right_begin,
|
||||
slice_right_end, slice_right_strided});
|
||||
make_tuple_a_inputs.push_back(slice_right);
|
||||
input_nodes.push_back(std::make_pair(slice_right, 1));
|
||||
}
|
||||
|
||||
auto make_tuple_a = graph->NewCNode(make_tuple_a_inputs);
|
||||
auto alltoall_attrs = CreateNeighborExchangeAttrs(cnode);
|
||||
auto alltoall_v = gen_g_.PushBack({gen_g_.NewOpInst(NEIGHBOREXCHANGE, alltoall_attrs), make_tuple_a});
|
||||
|
||||
AnfNodePtr conv2d;
|
||||
Attr concat_axis = {AXIS, MakeValue(-1)};
|
||||
OperatorAttrs concat_attrs = {concat_axis};
|
||||
|
||||
if (left_need_recv_) {
|
||||
std::vector<AnfNodePtr> tuple_getitem_l_inputs = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
|
||||
CreatInt64Imm(0)};
|
||||
auto tuple_getitem_l = graph->NewCNode(tuple_getitem_l_inputs);
|
||||
std::vector<AnfNodePtr> make_tuple_l_inputs = {NewValueNode(prim::kPrimMakeTuple), tuple_getitem_l,
|
||||
cnode->input(1)};
|
||||
auto make_tuple_l = graph->NewCNode(make_tuple_l_inputs);
|
||||
auto concat_l = gen_g_.PushBack({gen_g_.NewOpInst(CONCAT, concat_attrs), make_tuple_l});
|
||||
|
||||
if (right_need_recv_) {
|
||||
std::vector<AnfNodePtr> tuple_getitem_r_inputs = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
|
||||
CreatInt64Imm(1)};
|
||||
auto tuple_getitem_r = graph->NewCNode(tuple_getitem_r_inputs);
|
||||
std::vector<AnfNodePtr> make_tuple_r_inputs = {NewValueNode(prim::kPrimMakeTuple), concat_l, tuple_getitem_r};
|
||||
auto make_tuple_r = graph->NewCNode(make_tuple_r_inputs);
|
||||
auto concat_r = gen_g_.PushBack({gen_g_.NewOpInst(CONCAT, concat_attrs), make_tuple_r});
|
||||
conv2d = GenerateConv2DNode(concat_r, cnode);
|
||||
} else {
|
||||
conv2d = GenerateConv2DNode(concat_l, cnode);
|
||||
}
|
||||
} else { // left no need recv, and right need recv
|
||||
std::vector<AnfNodePtr> tuple_getitem_r_inputs_1 = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
|
||||
CreatInt64Imm(0)};
|
||||
auto tuple_getitem_r_1 = graph->NewCNode(tuple_getitem_r_inputs_1);
|
||||
std::vector<AnfNodePtr> make_tuple_r_inputs_1 = {NewValueNode(prim::kPrimMakeTuple), gen_g_.virtual_input_node(),
|
||||
tuple_getitem_r_1};
|
||||
auto make_tuple_r_1 = graph->NewCNode(make_tuple_r_inputs_1);
|
||||
input_nodes.push_back(std::make_pair(make_tuple_r_1, 1));
|
||||
|
||||
auto concat_r_1 = gen_g_.PushBack({gen_g_.NewOpInst(CONCAT, concat_attrs), make_tuple_r_1});
|
||||
conv2d = GenerateConv2DNode(concat_r_1, cnode);
|
||||
}
|
||||
auto conv2d = GenerateConv2DNode(neighbor_exchange_v2_node, cnode);
|
||||
|
||||
std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(neighbor_exchange_v2_node, 1)};
|
||||
replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
|
||||
std::make_pair(input_nodes, conv2d));
|
||||
}
|
||||
|
|
|
@ -53,13 +53,11 @@ class Conv2DInfo : public OperatorInfo {
|
|||
Status InferRankBias();
|
||||
void InferOverlapSize();
|
||||
void InferNewOperatorAttrs();
|
||||
void InferSendRecvFlag();
|
||||
void InferOverlapShapes();
|
||||
void InferStridedSliceAttrs();
|
||||
void InferCommunicationAttrs();
|
||||
std::string ReplaceNodeName() const;
|
||||
AnfNodePtr GenerateConv2DNode(const AnfNodePtr &new_input, const CNodePtr &cnode);
|
||||
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
|
||||
OperatorAttrs CreateNeighborExchangeAttrs(const CNodePtr &cnode);
|
||||
OperatorAttrs CreateNeighborExchangeV2Attrs();
|
||||
OperatorAttrs CreateConv2DAttrs();
|
||||
void ComputeReplaceGraph(const CNodePtr &cnode);
|
||||
|
||||
|
@ -91,21 +89,11 @@ class Conv2DInfo : public OperatorInfo {
|
|||
int64_t w_dimension_shard_num_ = 1;
|
||||
Shape input_slice_shape_;
|
||||
|
||||
bool left_need_send_ = false;
|
||||
bool left_need_recv_ = false;
|
||||
bool right_need_send_ = false;
|
||||
bool right_need_recv_ = false;
|
||||
Shape left_strided_slice_begin_;
|
||||
Shape left_strided_slice_end_;
|
||||
Shape left_strided_slice_strides_;
|
||||
Shape right_strided_slice_begin_;
|
||||
Shape right_strided_slice_end_;
|
||||
Shape right_strided_slice_strides_;
|
||||
|
||||
std::vector<int64_t> send_rank_ids_;
|
||||
std::vector<int64_t> recv_rank_ids_;
|
||||
Shapes send_shapes_;
|
||||
Shapes recv_shapes_;
|
||||
std::vector<int64_t> send_lens_;
|
||||
std::vector<int64_t> recv_lens_;
|
||||
std::string all_to_all_group_;
|
||||
|
||||
GenerateGraph gen_g_ = GenerateGraph(attrs_);
|
||||
|
||||
|
|
|
@ -70,6 +70,11 @@ def compile_net(net):
|
|||
|
||||
|
||||
def test_conv2d_transpose_data_parallel():
|
||||
"""
|
||||
Feature: test data parallel strategy
|
||||
Description: only shard batch dimension
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
|
||||
strategy2 = ((8, 1, 1, 1),)
|
||||
|
@ -78,6 +83,11 @@ def test_conv2d_transpose_data_parallel():
|
|||
|
||||
|
||||
def test_conv2d_transpose_model_parallel1():
|
||||
"""
|
||||
Feature: test model parallel strategy
|
||||
Description: only shard batch dimension and channel dimension
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
|
||||
strategy2 = ((8, 1, 1, 1),)
|
||||
|
@ -86,6 +96,11 @@ def test_conv2d_transpose_model_parallel1():
|
|||
|
||||
|
||||
def test_conv2d_transpose_model_parallel2():
|
||||
"""
|
||||
Feature: test model parallel strategy
|
||||
Description: shard batch dimension and w dimension
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 1, 1, 4), (1, 1, 1, 1))
|
||||
strategy2 = ((2, 1, 1, 4),)
|
||||
|
@ -95,6 +110,11 @@ def test_conv2d_transpose_model_parallel2():
|
|||
|
||||
|
||||
def test_conv2d_transpose_model_parallel3():
|
||||
"""
|
||||
Feature: test model parallel strategy
|
||||
Description: shard batch dimension, channel dimension and w dimension
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
||||
strategy1 = ((2, 2, 1, 4), (2, 1, 1, 1))
|
||||
strategy2 = ((2, 2, 1, 4),)
|
||||
|
@ -104,6 +124,11 @@ def test_conv2d_transpose_model_parallel3():
|
|||
|
||||
|
||||
def test_conv2d_transpose_all_rank_no_need_overlap():
|
||||
"""
|
||||
Feature: test model parallel strategy
|
||||
Description: shard batch dimension, channel dimension and w dimension
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
||||
strategy1 = ((2, 2, 1, 4), (2, 1, 1, 1))
|
||||
strategy2 = ((2, 2, 1, 4),)
|
||||
|
@ -113,6 +138,11 @@ def test_conv2d_transpose_all_rank_no_need_overlap():
|
|||
|
||||
|
||||
def test_conv2d_transpose_split_h_or_w_in_pad_mode():
|
||||
"""
|
||||
Feature: test pad mode
|
||||
Description: shard batch dimension, channel dimension and w dimension in pad mode
|
||||
Expectation: compile failed
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
||||
strategy1 = ((2, 2, 1, 4), (2, 1, 1, 1))
|
||||
strategy2 = ((2, 2, 1, 4),)
|
||||
|
@ -123,6 +153,11 @@ def test_conv2d_transpose_split_h_or_w_in_pad_mode():
|
|||
|
||||
|
||||
def test_conv2d_transpose_split_h_in_same_mode():
|
||||
"""
|
||||
Feature: test split h dimension
|
||||
Description: shard h dimension in same mode
|
||||
Expectation: compile failed
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
||||
strategy1 = ((2, 2, 4, 1), (2, 1, 1, 1))
|
||||
strategy2 = ((2, 2, 1, 4),)
|
||||
|
@ -133,6 +168,11 @@ def test_conv2d_transpose_split_h_in_same_mode():
|
|||
|
||||
|
||||
def test_conv2d_transpose_overlap_size_too_large():
|
||||
"""
|
||||
Feature: test overlap size is too large
|
||||
Description: shard w dimension and overlap size larger than slice shape
|
||||
Expectation: compile failed
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
|
||||
strategy2 = ((1, 1, 1, 8),)
|
||||
|
@ -140,24 +180,4 @@ def test_conv2d_transpose_overlap_size_too_large():
|
|||
strategy1=strategy1, strategy2=strategy2)
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_conv2d_transpose_overlap_size_too_large2():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
||||
strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
|
||||
strategy2 = ((2, 2, 1, 4),)
|
||||
net = Net2(_w2, out_channel=8, kernel_size=(4, 4), pad_mode="same", stride=2,
|
||||
strategy1=strategy1, strategy2=strategy2)
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_conv2d_transpose_rank0_no_need_overlap():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
||||
strategy1 = ((2, 2, 1, 4), (2, 1, 1, 1))
|
||||
strategy2 = ((2, 2, 1, 4),)
|
||||
net = Net2(_w4, out_channel=8, kernel_size=(3, 3), pad_mode="same", stride=2,
|
||||
strategy1=strategy1, strategy2=strategy2)
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net)
|
||||
|
||||
|
Loading…
Reference in New Issue