diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc index 94db178195e..04c546432c0 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc @@ -302,12 +302,12 @@ Status Conv2DInfo::CheckStrategy(const StrategyPtr &strategy) { } } - // if the h/w dimension is split, need to exchange overlap - if (input_strategy[2] > 1) { + // if the h/w dimension is split, and the pad mode is not "valid", need to exchange overlap + if (input_strategy[2] > 1 && pad_mode_ != 2) { h_dim_need_exchange_overlap_ = true; } - if (input_strategy[3] > 1) { + if (input_strategy[3] > 1 && pad_mode_ != 2) { w_dim_need_exchange_overlap_ = true; } return SUCCESS; @@ -556,8 +556,12 @@ void Conv2DInfo::InferOverlapSizeForWDim() { } } -void Conv2DInfo::CheckOverlapSizeNonNegative() { - // check h dimension +void Conv2DInfo::CheckHDimensionOverlapSizeNonNegative() { + if (h_dimension_shard_num_ == 1) { + MS_LOG(INFO) << name_ << ": The h dimension is not shard"; + return; + } + int64_t h_first_rank_bottom_size = ComputeOverlapBottomSizeByRankBias(0); if (h_first_rank_bottom_size < 0) { MS_LOG(EXCEPTION) << name_ << ": The bottom overlap size of h dimension rank bias 0 must be positive, but it is " @@ -579,8 +583,13 @@ void Conv2DInfo::CheckOverlapSizeNonNegative() { MS_LOG(EXCEPTION) << name_ << ": The top overlap size of h dimension last rank bias must be positive, but it is " << h_last_rank_top_size; } +} - // check w dimension +void Conv2DInfo::CheckWDimensionOverlapSizeNonNegative() { + if (w_dimension_shard_num_ == 1) { + MS_LOG(INFO) << name_ << ": The w dimension is not shard"; + return; + } int64_t w_first_rank_right_size = ComputeOverlapRightSizeByRankBias(0); if (w_first_rank_right_size < 0) { MS_LOG(EXCEPTION) << name_ << ": The right overlap size of w dimension rank bias 0 must be positive, but it is " @@ -604,6 +613,11 @@ void Conv2DInfo::CheckOverlapSizeNonNegative() { } } +void Conv2DInfo::CheckOverlapSizeNonNegative() { + CheckHDimensionOverlapSizeNonNegative(); + CheckWDimensionOverlapSizeNonNegative(); +} + void Conv2DInfo::InferOverlapSize() { InferOverlapSizeForHDim(); InferOverlapSizeForWDim(); diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h index 445c167aab6..6bbd4b2f3a7 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h @@ -53,6 +53,8 @@ class Conv2DInfo : public OperatorInfo { void InferAdjacentRankInfo(); std::vector GetAdjacentRankIdsAndBiases(int64_t rank_id, const std::string &dimension); void InferOverlapSize(); + void CheckHDimensionOverlapSizeNonNegative(); + void CheckWDimensionOverlapSizeNonNegative(); void CheckOverlapSizeNonNegative(); void InferOverlapSizeForHDim(); void InferOverlapSizeForWDim(); diff --git a/tests/ut/python/parallel/test_conv2d.py b/tests/ut/python/parallel/test_conv2d.py index 6f91d3e27a9..980a31ceede 100644 --- a/tests/ut/python/parallel/test_conv2d.py +++ b/tests/ut/python/parallel/test_conv2d.py @@ -100,6 +100,21 @@ def test_conv2d_pad_mode(): compile_net(net, _x3) +def test_conv2d_valid_mode_output_shape_cannot_div_by_strategy(): + """ + Feature: test conv2d valid mode, and output shape can not div by strategy + Description: shard w + Expectation: compile failed + """ + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1)) + strategy2 = ((1, 1, 1, 1),) + net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="valid", stride=4, + strategy1=strategy1, strategy2=strategy2) + with pytest.raises(RuntimeError): + compile_net(net, _x3) + + def test_conv2d_data_parallel_invalid_stride(): """ Feature: test conv2d invalid stride