!7227 fix cpu slice with a certain scene

Merge pull request !7227 from baihuawei/fixslice
This commit is contained in:
mindspore-ci-bot 2020-10-15 09:11:34 +08:00 committed by Gitee
commit d4d3d286cb
3 changed files with 42 additions and 22 deletions

View File

@ -29,30 +29,16 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
strides_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, STRIDES); strides_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, STRIDES);
end_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, END); end_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, END);
TransArg(); TransArg();
for (size_t i = 0; i < begin_.size(); i++) { ClipBegin();
while (begin_[i] < 0) {
begin_[i] = begin_[i] + input_shape_[i];
}
if (begin_[i] > SizeToInt(input_shape_[i])) {
begin_[i] = input_shape_[i];
}
}
} else { } else {
auto sizes = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, SIZE); auto sizes = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, SIZE);
if (sizes.size() != input_shape_.size() || begin_.size() != input_shape_.size()) { if (sizes.size() != input_shape_.size() || begin_.size() != input_shape_.size()) {
MS_LOG(EXCEPTION) << "begin|size|input size must be equal"; MS_LOG(EXCEPTION) << "begin|size|input size must be equal";
} }
for (size_t i = 0; i < begin_.size(); i++) { ClipBegin();
while (begin_[i] < 0) {
begin_[i] = begin_[i] + input_shape_[i];
}
if (begin_[i] > SizeToInt(input_shape_[i])) {
begin_[i] = input_shape_[i];
}
}
for (size_t i = 0; i < sizes.size(); ++i) { for (size_t i = 0; i < sizes.size(); ++i) {
while (sizes[i] < 0) { while (sizes[i] < 0) {
sizes[i] = sizes[i] + input_shape_[i]; sizes[i] = sizes[i] + SizeToInt(input_shape_[i]);
} }
strides_.emplace_back(1); strides_.emplace_back(1);
end_.emplace_back(begin_[i] + sizes[i]); end_.emplace_back(begin_[i] + sizes[i]);
@ -62,7 +48,17 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_); CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_);
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_); CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
} }
void SliceCPUKernel::ClipBegin() {
for (size_t i = 0; i < begin_.size(); i++) {
if (begin_[i] < 0) {
auto k = begin_[i] + SizeToInt(input_shape_[i]);
begin_[i] = k < 0 ? 0 : k;
}
if (begin_[i] > SizeToInt(input_shape_[i])) {
begin_[i] = SizeToInt(input_shape_[i]);
}
}
}
void SliceCPUKernel::ExpandAllMemberDims() { void SliceCPUKernel::ExpandAllMemberDims() {
auto input_len = input_shape_.size(); auto input_len = input_shape_.size();
if (input_len < 4) { if (input_len < 4) {
@ -178,13 +174,13 @@ void SliceCPUKernel::TransArg() {
MS_LOG(EXCEPTION) << "slice stride cannot be zero"; MS_LOG(EXCEPTION) << "slice stride cannot be zero";
} }
if (end_[i] == 0 && begin_[i] < 0) { if (end_[i] == 0 && begin_[i] < 0) {
end_[i] = end_[i] + input_shape_[i]; end_[i] = end_[i] + SizeToInt(input_shape_[i]);
} }
while (end_[i] < 0) { if (end_[i] < 0) {
end_[i] = end_[i] + input_shape_[i]; end_[i] = end_[i] + SizeToInt(input_shape_[i]) < 0 ? 0 : end_[i] + SizeToInt(input_shape_[i]);
} }
if (end_[i] > SizeToInt(input_shape_[i])) { if (end_[i] > SizeToInt(input_shape_[i])) {
end_[i] = input_shape_[i]; end_[i] = SizeToInt(input_shape_[i]);
} }
} }
} }

View File

@ -41,6 +41,7 @@ class SliceCPUKernel : public CPUKernel {
int id) const; int id) const;
void CheckParam(const CNodePtr &kernel_node) const; void CheckParam(const CNodePtr &kernel_node) const;
void TransArg(); void TransArg();
void ClipBegin();
std::vector<int> begin_; std::vector<int> begin_;
std::vector<int> end_; std::vector<int> end_;
std::vector<int> strides_; std::vector<int> strides_;

View File

@ -138,9 +138,32 @@ def test_slice5():
assert (output.asnumpy() == inputx[0:3:1, 1:5:1, 0:4:1]).all() assert (output.asnumpy() == inputx[0:3:1, 1:5:1, 0:4:1]).all()
class Slice6(nn.Cell):
def __init__(self):
super(Slice6, self).__init__()
self.relu = nn.ReLU()
def construct(self, x):
return (x[-10:], x[-5:10:2, :, :], x[-10:10:1, :, -10:10:1])
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_slice6():
inputx = np.random.rand(4, 4, 4).astype(np.float32)
x = Tensor(inputx)
slice_op = Slice6()
output = slice_op(x)
assert (output[0].asnumpy() == inputx[-10:]).all()
assert (output[1].asnumpy() == inputx[-5:10:2, :, :]).all()
assert (output[2].asnumpy() == inputx[-10:10:1, :, -10:10:1]).all()
if __name__ == '__main__': if __name__ == '__main__':
test_slice() test_slice()
test_slice2() test_slice2()
test_slice3() test_slice3()
test_slice4() test_slice4()
test_slice5() test_slice5()
test_slice6()