!45073 fix bug for conv2d pad mode

Merge pull request !45073 from yangzhenzhang/fix-bug-for-conv2d-pad-mode
This commit is contained in:
i-robot 2022-11-04 01:26:08 +00:00 committed by Gitee
commit 8f0f05e344
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 20 additions and 2 deletions

View File

@ -146,8 +146,10 @@ void Conv2DInfo::AdjustPadList() {
return;
}
int64_t useless_len_2th_dim = (inputs_shape_[0][2] + pad_list_[0] + pad_list_[1] - kernel_size_[1]) % stride_[2];
int64_t useless_len_3th_dim = (inputs_shape_[0][3] + pad_list_[2] + pad_list_[3] - kernel_size_[2]) % stride_[3];
int64_t useless_len_2th_dim =
(inputs_shape_[0][2] + pad_list_[0] + pad_list_[1] - kernel_size_use_dilation_[0]) % stride_[2];
int64_t useless_len_3th_dim =
(inputs_shape_[0][3] + pad_list_[2] + pad_list_[3] - kernel_size_use_dilation_[1]) % stride_[3];
if (useless_len_2th_dim == 0 && useless_len_3th_dim == 0) {
return;
}

View File

@ -44,12 +44,14 @@ class Net(Cell):
_x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
_x2 = Tensor(np.ones([32, 16, 10, 10]), dtype=ms.float32)
_x3 = Tensor(np.ones([32, 16, 16, 16]), dtype=ms.float32)
_x4 = Tensor(np.ones([32, 4, 16, 24]), dtype=ms.float32)
_w0 = Tensor(np.ones([8, 16, 1, 1]), dtype=ms.float32)
_w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
_w2 = Tensor(np.ones([8, 16, 3, 3]), dtype=ms.float32)
_w3 = Tensor(np.ones([8, 16, 5, 5]), dtype=ms.float32)
_w4 = Tensor(np.ones([8, 8, 2, 2]), dtype=ms.float32)
_w5 = Tensor(np.ones([8, 16, 4, 4]), dtype=ms.float32)
_w6 = Tensor(np.ones([10, 2, 6, 5]), dtype=ms.float32)
_b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
@ -103,6 +105,20 @@ def test_conv2d_pad_mode():
compile_net(net, _x3)
def test_conv2d_pad_mode_2():
"""
Feature: test conv2d pad mode and overlap is non-negative
Description: shard w
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
strategy2 = ((1, 1, 1, 1),)
net = Net(_w6, out_channel=10, kernel_size=(6, 5), pad_mode="pad", stride=3, pad=(1, 1, 3, 3), dilation=2, group=2,
strategy1=strategy1, strategy2=strategy2)
compile_net(net, _x4)
def test_conv2d_valid_mode_output_shape_cannot_div_by_strategy():
"""
Feature: test conv2d valid mode, and output shape can not div by strategy