From de0b4d089f93b364ea145b0b812f147598b31f92 Mon Sep 17 00:00:00 2001 From: baihuawei Date: Fri, 9 Oct 2020 17:26:12 +0800 Subject: [PATCH] fix cpu slice --- .../kernel_compiler/cpu/slice_cpu_kernel.cc | 17 +++++------ tests/st/ops/cpu/test_slice_op.py | 28 +++++++++++++++++-- 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.cc index 0adbcc61701..d2708832e9a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.cc @@ -21,7 +21,6 @@ namespace kernel { void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) { CheckParam(kernel_node); input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); begin_ = AnfAlgo::GetNodeAttr>(kernel_node, BEGIN); auto prim = AnfAlgo::GetCNodePrimitive(kernel_node); MS_EXCEPTION_IF_NULL(prim); @@ -65,12 +64,6 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) { } void SliceCPUKernel::ExpandAllMemberDims() { - auto output_len = output_shape_.size(); - if (output_len < 4) { - for (size_t i = 0; i < 4 - output_len; ++i) { - output_shape_.push_back(1); - } - } auto input_len = input_shape_.size(); if (input_len < 4) { for (size_t i = 0; i < 4 - input_len; ++i) { @@ -80,6 +73,15 @@ void SliceCPUKernel::ExpandAllMemberDims() { end_.insert(end_.begin(), 1); } } + for (size_t i = 0; i < 4; ++i) { + if (SignOfStride(i)) { + int ax = (end_[i] - begin_[i]) * SignOfStride(i); + if (ax < 0) { + ax = 0; + } + output_shape_.push_back(IntToSize(ax)); + } + } } bool SliceCPUKernel::Launch(const std::vector &inputs, @@ -87,7 +89,6 @@ bool SliceCPUKernel::Launch(const std::vector &inputs, const std::vector &outputs) { auto input_addr = reinterpret_cast(inputs[0]->addr); auto output_addr = reinterpret_cast(outputs[0]->addr); - bool can_copy_memory[3] = {CanCopyMemoryOnAxis(0), CanCopyMemoryOnAxis(1), CanCopyMemoryOnAxis(2)}; int signstride[4] = {SignOfStride(0), SignOfStride(1), SignOfStride(2), SignOfStride(3)}; size_t in_start_offset[3] = {begin_[0] * input_element_num_[0], begin_[1] * input_element_num_[1], diff --git a/tests/st/ops/cpu/test_slice_op.py b/tests/st/ops/cpu/test_slice_op.py index 4af5690d774..d13a5ec5555 100644 --- a/tests/st/ops/cpu/test_slice_op.py +++ b/tests/st/ops/cpu/test_slice_op.py @@ -45,7 +45,6 @@ def test_slice(): slice_op = Slice() output = slice_op(x) - print("output:\n", output) assert (output.asnumpy() == expect).all() @@ -68,7 +67,6 @@ def test_slice2(): slice_op = Slice2() output = slice_op(x) - print("output:\n", output) assert (output.asnumpy() == expect).all() @@ -115,8 +113,34 @@ def test_slice4(): assert (output.asnumpy() == inputx[:10:1, :, 2:3:1]).all() +class Slice5(nn.Cell): + def __init__(self, begin, size): + super(Slice5, self).__init__() + self.relu = nn.ReLU() + self.slice = P.Slice() + self.begin = begin + self.size = size + + def construct(self, x): + return self.slice(x, self.begin, self.size) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_slice5(): + inputx = np.arange(3 * 5 * 4).reshape(3, 5, 4).astype(np.float32) + x = Tensor(inputx) + begin = (0, 1, 0) + size = (3, 4, 4) + slice_op = Slice5(begin, size) + output = slice_op(x) + assert (output.asnumpy() == inputx[0:3:1, 1:5:1, 0:4:1]).all() + + if __name__ == '__main__': test_slice() test_slice2() test_slice3() test_slice4() + test_slice5()