!1800 fix cpu StridedSliceGrad bug when different dims between input and output

Merge pull request !1800 from sunsuodong/fix_StrideSliceGrad
This commit is contained in:
mindspore-ci-bot 2020-06-02 20:37:55 +08:00 committed by Gitee
commit 71dce2f586
3 changed files with 34 additions and 7 deletions

View File

@ -61,11 +61,11 @@ void SliceGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
end_.emplace_back(begin_[i] + sizes[i]); end_.emplace_back(begin_[i] + sizes[i]);
} }
} }
CPUKernelUtils::ExpandDimsTo4(&output_dx_shape_);
auto input_len = input_dy_shape_.size(); auto output_len = output_dx_shape_.size();
if (input_len < 4) { if (output_len < 4) {
for (size_t i = 0; i < 4 - input_len; ++i) { for (size_t i = 0; i < 4 - output_len; ++i) {
input_dy_shape_.insert(input_dy_shape_.begin(), 1); output_dx_shape_.insert(output_dx_shape_.begin(), 1);
begin_.insert(begin_.begin(), 0); begin_.insert(begin_.begin(), 0);
strides_.insert(strides_.begin(), 1); strides_.insert(strides_.begin(), 1);
end_.insert(end_.begin(), 1); end_.insert(end_.begin(), 1);

View File

@ -19,6 +19,7 @@ import pytest
import mindspore.context as context import mindspore.context as context
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.common.api import ms_function from mindspore.common.api import ms_function
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G from mindspore.ops.operations import _grad_ops as G
@ -38,7 +39,7 @@ class StridedSliceGrad(nn.Cell):
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_cpu_training @pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_slice(): def test_slice():
x = Tensor(np.array([[[1., 1., 1.], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 7, 8]]]).astype(np.float32)) x = Tensor(np.array([[[1., 1., 1.], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 7, 8]]]).astype(np.float32))
@ -47,3 +48,29 @@ def test_slice():
output = ssg(dy, x) output = ssg(dy, x)
expect = [[[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]], [[5, 1, 5], [6, 1, 8]]] expect = [[[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]], [[5, 1, 5], [6, 1, 8]]]
assert (output.asnumpy() == expect).all() assert (output.asnumpy() == expect).all()
class StridedSliceGrad2(nn.Cell):
def __init__(self):
super(StridedSliceGrad2, self).__init__()
self.ssg = G.StridedSliceGrad()
self.shape = P.Shape()
@ms_function
def construct(self, dy, x):
return self.ssg(dy, self.shape(x), (0, 0, 0), (1, 4, 2), (1, 1, 1))
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_slice2():
x = Tensor(np.arange(2 * 4 * 2).reshape(2, 4, 2), mstype.float32)
dy = Tensor(np.arange(4 * 2).reshape(4, 2), mstype.float32)
ssg = StridedSliceGrad2()
output = ssg(dy, x)
expect = [[[0., 1.], [2., 3.], [4., 5.], [6., 7.]], [[0., 0.], [0., 0.], [0., 0.], [0., 0.]]]
assert (output.asnumpy() == expect).all()
if __name__ == '__main__':
test_slice()
test_slice2()

View File

@ -34,7 +34,7 @@ class StridedSlice(nn.Cell):
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_cpu_training @pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_slice(): def test_slice():
x = Tensor(np.array([[[1., 1., 1.], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 7, 8]]]).astype(np.float32)) x = Tensor(np.array([[[1., 1., 1.], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 7, 8]]]).astype(np.float32))