!7227 fix cpu slice with a certain scene

Merge pull request !7227 from baihuawei/fixslice
This commit is contained in:
mindspore-ci-bot 2020-10-15 09:11:34 +08:00 committed by Gitee
commit d4d3d286cb
3 changed files with 42 additions and 22 deletions

View File

@ -29,30 +29,16 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
strides_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, STRIDES);
end_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, END);
TransArg();
for (size_t i = 0; i < begin_.size(); i++) {
while (begin_[i] < 0) {
begin_[i] = begin_[i] + input_shape_[i];
}
if (begin_[i] > SizeToInt(input_shape_[i])) {
begin_[i] = input_shape_[i];
}
}
ClipBegin();
} else {
auto sizes = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, SIZE);
if (sizes.size() != input_shape_.size() || begin_.size() != input_shape_.size()) {
MS_LOG(EXCEPTION) << "begin|size|input size must be equal";
}
for (size_t i = 0; i < begin_.size(); i++) {
while (begin_[i] < 0) {
begin_[i] = begin_[i] + input_shape_[i];
}
if (begin_[i] > SizeToInt(input_shape_[i])) {
begin_[i] = input_shape_[i];
}
}
ClipBegin();
for (size_t i = 0; i < sizes.size(); ++i) {
while (sizes[i] < 0) {
sizes[i] = sizes[i] + input_shape_[i];
sizes[i] = sizes[i] + SizeToInt(input_shape_[i]);
}
strides_.emplace_back(1);
end_.emplace_back(begin_[i] + sizes[i]);
@ -62,7 +48,17 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_);
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
}
void SliceCPUKernel::ClipBegin() {
for (size_t i = 0; i < begin_.size(); i++) {
if (begin_[i] < 0) {
auto k = begin_[i] + SizeToInt(input_shape_[i]);
begin_[i] = k < 0 ? 0 : k;
}
if (begin_[i] > SizeToInt(input_shape_[i])) {
begin_[i] = SizeToInt(input_shape_[i]);
}
}
}
void SliceCPUKernel::ExpandAllMemberDims() {
auto input_len = input_shape_.size();
if (input_len < 4) {
@ -178,13 +174,13 @@ void SliceCPUKernel::TransArg() {
MS_LOG(EXCEPTION) << "slice stride cannot be zero";
}
if (end_[i] == 0 && begin_[i] < 0) {
end_[i] = end_[i] + input_shape_[i];
end_[i] = end_[i] + SizeToInt(input_shape_[i]);
}
while (end_[i] < 0) {
end_[i] = end_[i] + input_shape_[i];
if (end_[i] < 0) {
end_[i] = end_[i] + SizeToInt(input_shape_[i]) < 0 ? 0 : end_[i] + SizeToInt(input_shape_[i]);
}
if (end_[i] > SizeToInt(input_shape_[i])) {
end_[i] = input_shape_[i];
end_[i] = SizeToInt(input_shape_[i]);
}
}
}

View File

@ -41,6 +41,7 @@ class SliceCPUKernel : public CPUKernel {
int id) const;
void CheckParam(const CNodePtr &kernel_node) const;
void TransArg();
void ClipBegin();
std::vector<int> begin_;
std::vector<int> end_;
std::vector<int> strides_;

View File

@ -138,9 +138,32 @@ def test_slice5():
assert (output.asnumpy() == inputx[0:3:1, 1:5:1, 0:4:1]).all()
class Slice6(nn.Cell):
def __init__(self):
super(Slice6, self).__init__()
self.relu = nn.ReLU()
def construct(self, x):
return (x[-10:], x[-5:10:2, :, :], x[-10:10:1, :, -10:10:1])
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_slice6():
inputx = np.random.rand(4, 4, 4).astype(np.float32)
x = Tensor(inputx)
slice_op = Slice6()
output = slice_op(x)
assert (output[0].asnumpy() == inputx[-10:]).all()
assert (output[1].asnumpy() == inputx[-5:10:2, :, :]).all()
assert (output[2].asnumpy() == inputx[-10:10:1, :, -10:10:1]).all()
if __name__ == '__main__':
test_slice()
test_slice2()
test_slice3()
test_slice4()
test_slice5()
test_slice6()