!45073 fix bug for conv2d pad mode
Merge pull request !45073 from yangzhenzhang/fix-bug-for-conv2d-pad-mode
This commit is contained in:
commit
8f0f05e344
|
@ -146,8 +146,10 @@ void Conv2DInfo::AdjustPadList() {
|
|||
return;
|
||||
}
|
||||
|
||||
int64_t useless_len_2th_dim = (inputs_shape_[0][2] + pad_list_[0] + pad_list_[1] - kernel_size_[1]) % stride_[2];
|
||||
int64_t useless_len_3th_dim = (inputs_shape_[0][3] + pad_list_[2] + pad_list_[3] - kernel_size_[2]) % stride_[3];
|
||||
int64_t useless_len_2th_dim =
|
||||
(inputs_shape_[0][2] + pad_list_[0] + pad_list_[1] - kernel_size_use_dilation_[0]) % stride_[2];
|
||||
int64_t useless_len_3th_dim =
|
||||
(inputs_shape_[0][3] + pad_list_[2] + pad_list_[3] - kernel_size_use_dilation_[1]) % stride_[3];
|
||||
if (useless_len_2th_dim == 0 && useless_len_3th_dim == 0) {
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -44,12 +44,14 @@ class Net(Cell):
|
|||
_x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
|
||||
_x2 = Tensor(np.ones([32, 16, 10, 10]), dtype=ms.float32)
|
||||
_x3 = Tensor(np.ones([32, 16, 16, 16]), dtype=ms.float32)
|
||||
_x4 = Tensor(np.ones([32, 4, 16, 24]), dtype=ms.float32)
|
||||
_w0 = Tensor(np.ones([8, 16, 1, 1]), dtype=ms.float32)
|
||||
_w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
|
||||
_w2 = Tensor(np.ones([8, 16, 3, 3]), dtype=ms.float32)
|
||||
_w3 = Tensor(np.ones([8, 16, 5, 5]), dtype=ms.float32)
|
||||
_w4 = Tensor(np.ones([8, 8, 2, 2]), dtype=ms.float32)
|
||||
_w5 = Tensor(np.ones([8, 16, 4, 4]), dtype=ms.float32)
|
||||
_w6 = Tensor(np.ones([10, 2, 6, 5]), dtype=ms.float32)
|
||||
_b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
|
||||
|
||||
|
||||
|
@ -103,6 +105,20 @@ def test_conv2d_pad_mode():
|
|||
compile_net(net, _x3)
|
||||
|
||||
|
||||
def test_conv2d_pad_mode_2():
|
||||
"""
|
||||
Feature: test conv2d pad mode and overlap is non-negative
|
||||
Description: shard w
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
|
||||
strategy2 = ((1, 1, 1, 1),)
|
||||
net = Net(_w6, out_channel=10, kernel_size=(6, 5), pad_mode="pad", stride=3, pad=(1, 1, 3, 3), dilation=2, group=2,
|
||||
strategy1=strategy1, strategy2=strategy2)
|
||||
compile_net(net, _x4)
|
||||
|
||||
|
||||
def test_conv2d_valid_mode_output_shape_cannot_div_by_strategy():
|
||||
"""
|
||||
Feature: test conv2d valid mode, and output shape can not div by strategy
|
||||
|
|
Loading…
Reference in New Issue