diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.cc index d2708832e9a..2354a7e76f1 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.cc @@ -29,30 +29,16 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) { strides_ = AnfAlgo::GetNodeAttr>(kernel_node, STRIDES); end_ = AnfAlgo::GetNodeAttr>(kernel_node, END); TransArg(); - for (size_t i = 0; i < begin_.size(); i++) { - while (begin_[i] < 0) { - begin_[i] = begin_[i] + input_shape_[i]; - } - if (begin_[i] > SizeToInt(input_shape_[i])) { - begin_[i] = input_shape_[i]; - } - } + ClipBegin(); } else { auto sizes = AnfAlgo::GetNodeAttr>(kernel_node, SIZE); if (sizes.size() != input_shape_.size() || begin_.size() != input_shape_.size()) { MS_LOG(EXCEPTION) << "begin|size|input size must be equal"; } - for (size_t i = 0; i < begin_.size(); i++) { - while (begin_[i] < 0) { - begin_[i] = begin_[i] + input_shape_[i]; - } - if (begin_[i] > SizeToInt(input_shape_[i])) { - begin_[i] = input_shape_[i]; - } - } + ClipBegin(); for (size_t i = 0; i < sizes.size(); ++i) { while (sizes[i] < 0) { - sizes[i] = sizes[i] + input_shape_[i]; + sizes[i] = sizes[i] + SizeToInt(input_shape_[i]); } strides_.emplace_back(1); end_.emplace_back(begin_[i] + sizes[i]); @@ -62,7 +48,17 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) { CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_); CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_); } - +void SliceCPUKernel::ClipBegin() { + for (size_t i = 0; i < begin_.size(); i++) { + if (begin_[i] < 0) { + auto k = begin_[i] + SizeToInt(input_shape_[i]); + begin_[i] = k < 0 ? 0 : k; + } + if (begin_[i] > SizeToInt(input_shape_[i])) { + begin_[i] = SizeToInt(input_shape_[i]); + } + } +} void SliceCPUKernel::ExpandAllMemberDims() { auto input_len = input_shape_.size(); if (input_len < 4) { @@ -178,13 +174,13 @@ void SliceCPUKernel::TransArg() { MS_LOG(EXCEPTION) << "slice stride cannot be zero"; } if (end_[i] == 0 && begin_[i] < 0) { - end_[i] = end_[i] + input_shape_[i]; + end_[i] = end_[i] + SizeToInt(input_shape_[i]); } - while (end_[i] < 0) { - end_[i] = end_[i] + input_shape_[i]; + if (end_[i] < 0) { + end_[i] = end_[i] + SizeToInt(input_shape_[i]) < 0 ? 0 : end_[i] + SizeToInt(input_shape_[i]); } if (end_[i] > SizeToInt(input_shape_[i])) { - end_[i] = input_shape_[i]; + end_[i] = SizeToInt(input_shape_[i]); } } } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.h index 3afc0464fe4..f3d7e768a5f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.h @@ -41,6 +41,7 @@ class SliceCPUKernel : public CPUKernel { int id) const; void CheckParam(const CNodePtr &kernel_node) const; void TransArg(); + void ClipBegin(); std::vector begin_; std::vector end_; std::vector strides_; diff --git a/tests/st/ops/cpu/test_slice_op.py b/tests/st/ops/cpu/test_slice_op.py index d13a5ec5555..0fba3a11b45 100644 --- a/tests/st/ops/cpu/test_slice_op.py +++ b/tests/st/ops/cpu/test_slice_op.py @@ -138,9 +138,32 @@ def test_slice5(): assert (output.asnumpy() == inputx[0:3:1, 1:5:1, 0:4:1]).all() +class Slice6(nn.Cell): + def __init__(self): + super(Slice6, self).__init__() + self.relu = nn.ReLU() + + def construct(self, x): + return (x[-10:], x[-5:10:2, :, :], x[-10:10:1, :, -10:10:1]) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_slice6(): + inputx = np.random.rand(4, 4, 4).astype(np.float32) + x = Tensor(inputx) + slice_op = Slice6() + output = slice_op(x) + assert (output[0].asnumpy() == inputx[-10:]).all() + assert (output[1].asnumpy() == inputx[-5:10:2, :, :]).all() + assert (output[2].asnumpy() == inputx[-10:10:1, :, -10:10:1]).all() + + if __name__ == '__main__': test_slice() test_slice2() test_slice3() test_slice4() test_slice5() + test_slice6()