diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc
index 092e63f15ae..3c8c06e5931 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc
@@ -103,10 +103,6 @@ Status Conv2DInfo::GetAttrsBase() {
 
   // group
   group_ = GetIntAttr(GROUP);
-  if (group_ != 1) {
-    MS_LOG(ERROR) << name_ << ": The group must be 1, but got " << group_;
-    return FAILED;
-  }
 
   // format
   format_ = GetStringAttr(FORMAT);
@@ -176,10 +172,10 @@ Status Conv2DInfo::CheckHWStrategySameMode(int64_t h_strategy, int64_t w_strateg
     return FAILED;
   }
 
-  if (w_slice_shape < ((kernel_size_[1] - stride_[3] + 1) / 2)) {
+  if (w_slice_shape <= ((kernel_size_[1] - stride_[3] + 1) / 2)) {
     MS_LOG(ERROR) << name_
                   << ": The 'same' mode do not support to split W when kernel_size > stride but w slice shape is "
-                     "smaller than (k - s + 1) / 2";
+                     "smaller than or equal to (k - s + 1) / 2";
     return FAILED;
   }
 
@@ -275,6 +271,23 @@ Status Conv2DInfo::CheckStrategyBase(const StrategyPtr &strategy) {
     new_out_channel_ = out_channel_;
   }
 
+  int64_t input_except_n_shards =
+    std::accumulate(input_strategy.begin() + 1, input_strategy.end(), 1, std::multiplies<int64_t>());
+  int64_t weight_shards =
+    std::accumulate(weight_strategy.begin() + 1, weight_strategy.end(), 1, std::multiplies<int64_t>());
+
+  bool is_data_parallel = (input_except_n_shards * weight_shards == 1);
+  if (!is_data_parallel) {
+    if (std::any_of(dilation_.begin(), dilation_.end(), [](int64_t value) { return value != 1; })) {
+      MS_LOG(ERROR) << name_ << ": If it is not data parallel, the value of dilation must be 1, but got " << dilation_;
+      return FAILED;
+    }
+
+    if (group_ != 1) {
+      MS_LOG(ERROR) << name_ << ": If it is not data parallel, the group must be 1, but got " << group_;
+      return FAILED;
+    }
+  }
   return SUCCESS;
 }
 
@@ -536,17 +549,17 @@ void Conv2DInfo::InferSendRecvFlag() {
                << right_need_recv_;
 
   if (left_need_send_) {
-    if (left_rank_overlap_right_size_ > input_slice_shape_[3]) {
+    if (left_rank_overlap_right_size_ >= input_slice_shape_[3]) {
       MS_LOG(EXCEPTION) << name_ << ": Do not support left overlap size(" << left_rank_overlap_right_size_
-                        << ") larger than slice shape in w dimension(" << input_slice_shape_[3] << ")";
+                        << ") larger than or equal to slice shape in w dimension(" << input_slice_shape_[3] << ")";
     }
     send_rank_ids_.push_back(left_rank_id_);
   }
 
   if (right_need_send_) {
-    if (right_rank_overlap_left_size_ > input_slice_shape_[3]) {
+    if (right_rank_overlap_left_size_ >= input_slice_shape_[3]) {
       MS_LOG(EXCEPTION) << name_ << ": Do not support left overlap size(" << right_rank_overlap_left_size_
-                        << ") larger than slice shape in w dimension(" << input_slice_shape_[3] << ")";
+                        << ") larger than or equal to slice shape in w dimension(" << input_slice_shape_[3] << ")";
     }
     send_rank_ids_.push_back(right_rank_id_);
   }
@@ -862,8 +875,8 @@ Status Conv2DBackpropInputInfo::GetOutShape() {
   for (auto &element : elements) {
    MS_EXCEPTION_IF_NULL(element);
    if (element->isa<Int64Imm>()) {
-      int64_t axis = element->cast<Int64ImmPtr>()->value();
-      out_shape_.push_back(axis);
+      int64_t ele_value = element->cast<Int64ImmPtr>()->value();
+      out_shape_.push_back(ele_value);
     } else {
       MS_LOG(ERROR) << name_ << ": The value of shape must be int";
       return FAILED;
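Note: the heart of this change is the new check in `Conv2DInfo::CheckStrategyBase`. `dilation != 1` and `group != 1` are no longer rejected outright; they are rejected only when the sharding strategy is not pure data parallelism, i.e. when anything other than the input's batch (N) dimension is split. Below is a minimal Python sketch of that predicate for sanity-checking a strategy by hand; `is_data_parallel` is a hypothetical helper name, not part of MindSpore's API.

```python
from functools import reduce
from operator import mul

def is_data_parallel(input_strategy, weight_strategy):
    # Hypothetical mirror of the C++ check above: multiply every shard count
    # except the input's N dimension and the weight's first dimension; a
    # product of 1 means only the batch dimension is split.
    input_except_n_shards = reduce(mul, input_strategy[1:], 1)
    weight_shards = reduce(mul, weight_strategy[1:], 1)
    return input_except_n_shards * weight_shards == 1

assert is_data_parallel((8, 1, 1, 1), (1, 1, 1, 1))      # dilation/group allowed
assert not is_data_parallel((2, 2, 1, 1), (2, 2, 1, 1))  # dilation/group rejected
```

The two strategies in the asserts are exactly the ones used by the new unit tests below.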
diff --git a/tests/ut/python/parallel/test_conv2d.py b/tests/ut/python/parallel/test_conv2d.py
index 08086e030bb..34e63d04e7b 100644
--- a/tests/ut/python/parallel/test_conv2d.py
+++ b/tests/ut/python/parallel/test_conv2d.py
@@ -23,11 +23,11 @@ from mindspore.ops import operations as P
 
 
 class Net(Cell):
-    def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride,
+    def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, dilation=1, group=1,
                  strategy1=None, strategy2=None):
         super().__init__()
         self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
-                               pad_mode=pad_mode, stride=stride).shard(strategy1)
+                               pad_mode=pad_mode, stride=stride, dilation=dilation, group=group).shard(strategy1)
         self.neg = P.Neg().shard(strategy2)
 
         self.conv2d_weight = Parameter(conv2d_weight, "w1")
@@ -43,6 +43,7 @@
 _w0 = Tensor(np.ones([8, 16, 1, 1]), dtype=ms.float32)
 _w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
 _w2 = Tensor(np.ones([8, 16, 3, 3]), dtype=ms.float32)
 _w3 = Tensor(np.ones([8, 16, 5, 5]), dtype=ms.float32)
+_w4 = Tensor(np.ones([8, 8, 2, 2]), dtype=ms.float32)
 _b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
 
@@ -63,6 +64,24 @@ def test_conv2d_data_parallel():
     compile_net(net)
 
 
+def test_conv2d_data_parallel_dilation():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
+    strategy2 = ((8, 1, 1, 1),)
+    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, dilation=2,
+              strategy1=strategy1, strategy2=strategy2)
+    compile_net(net)
+
+
+def test_conv2d_data_parallel_group():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
+    strategy2 = ((8, 1, 1, 1),)
+    net = Net(_w4, out_channel=8, kernel_size=2, pad_mode="same", stride=1, group=2,
+              strategy1=strategy1, strategy2=strategy2)
+    compile_net(net)
+
+
 def test_conv2d_model_parallel1():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
@@ -71,6 +90,26 @@ def test_conv2d_model_parallel1():
     compile_net(net)
 
 
+def test_conv2d_model_parallel_dilation():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
+    strategy2 = ((8, 1, 1, 1),)
+    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, dilation=2,
+              strategy1=strategy1, strategy2=strategy2)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
+def test_conv2d_model_parallel_group():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
+    strategy2 = ((8, 1, 1, 1),)
+    net = Net(_w4, out_channel=8, kernel_size=2, pad_mode="same", stride=1, group=2,
+              strategy1=strategy1, strategy2=strategy2)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
 def test_conv2d_model_parallel2():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
     strategy1 = ((2, 2, 2, 2), (2, 2, 1, 1))
@@ -182,6 +221,15 @@ def test_kernel_size_larger_than_stride_and_slice_too_small():
     compile_net(net)
 
 
+def test_conv2d_same_mode_overlap_size_equal_to_slice_shape():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
+    strategy2 = ((2, 1, 1, 4),)
+    net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
 def test_kernel_size_larger_than_stride_and_left_pad_is_0():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy1 = ((1, 1, 1, 4), (1, 1, 1, 1))
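The new negative test `test_conv2d_same_mode_overlap_size_equal_to_slice_shape` pins down the `<` to `<=` fix in `CheckHWStrategySameMode`. With the shapes in this file (`_b` gives W = 8, `_w2` gives kernel_size = 3, stride = 1), an 8-way split of W leaves a slice exactly equal to `(k - s + 1) / 2`, which the old strict comparison let through. A quick arithmetic check of that boundary, assuming integer division as in the C++ code:

```python
# Boundary arithmetic behind test_conv2d_same_mode_overlap_size_equal_to_slice_shape.
w_dim, kernel, stride, w_shards = 8, 3, 1, 8   # from _b, _w2 and strategy1 above

w_slice_shape = w_dim // w_shards              # 8 // 8 = 1
threshold = (kernel - stride + 1) // 2         # (3 - 1 + 1) // 2 = 1

# The old check only rejected w_slice_shape < threshold, so the equal case
# slipped through; the patched check rejects it up front.
assert w_slice_shape == threshold
```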
diff --git a/tests/ut/python/parallel/test_conv2d_transpose.py b/tests/ut/python/parallel/test_conv2d_transpose.py
index 9e6316d4ca5..bd2cd2d32f1 100644
--- a/tests/ut/python/parallel/test_conv2d_transpose.py
+++ b/tests/ut/python/parallel/test_conv2d_transpose.py
@@ -122,6 +122,16 @@ def test_conv2d_transpose_overlap_size_too_large():
     compile_net(net)
 
 
+def test_conv2d_transpose_overlap_size_too_large2():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
+    strategy2 = ((2, 2, 1, 4),)
+    net = Net2(_w2, out_channel=8, kernel_size=(4, 4), pad_mode="same", stride=2,
+               strategy1=strategy1, strategy2=strategy2)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
 def test_conv2d_transpose_rank0_no_need_overlap():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
     strategy1 = ((2, 2, 1, 4), (2, 1, 1, 1))
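The transpose test exercises the same tightening in `Conv2DInfo::InferSendRecvFlag`: a rank may no longer exchange a halo (overlap region) that covers its entire W slice, not just one that exceeds it. A minimal sketch of the tightened guard, with `check_w_overlap` as a hypothetical stand-in for the C++ logic:

```python
def check_w_overlap(overlap_size, w_slice_shape):
    # Hypothetical mirror of the >= guard in InferSendRecvFlag: the halo a rank
    # exchanges with a neighbour must be strictly smaller than its own W slice.
    if overlap_size >= w_slice_shape:
        raise RuntimeError(f"overlap size ({overlap_size}) must be smaller than "
                           f"the w slice shape ({w_slice_shape})")

check_w_overlap(1, 2)      # fine: halo smaller than the slice
try:
    check_w_overlap(2, 2)  # the equal case is now rejected as well
except RuntimeError as err:
    print(err)
```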