forked from mindspore-Ecosystem/mindspore
!7227 fix cpu slice with a certain scene
Merge pull request !7227 from baihuawei/fixslice
This commit is contained in:
commit
d4d3d286cb
|
@ -29,30 +29,16 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||||
strides_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, STRIDES);
|
strides_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, STRIDES);
|
||||||
end_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, END);
|
end_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, END);
|
||||||
TransArg();
|
TransArg();
|
||||||
for (size_t i = 0; i < begin_.size(); i++) {
|
ClipBegin();
|
||||||
while (begin_[i] < 0) {
|
|
||||||
begin_[i] = begin_[i] + input_shape_[i];
|
|
||||||
}
|
|
||||||
if (begin_[i] > SizeToInt(input_shape_[i])) {
|
|
||||||
begin_[i] = input_shape_[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
auto sizes = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, SIZE);
|
auto sizes = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, SIZE);
|
||||||
if (sizes.size() != input_shape_.size() || begin_.size() != input_shape_.size()) {
|
if (sizes.size() != input_shape_.size() || begin_.size() != input_shape_.size()) {
|
||||||
MS_LOG(EXCEPTION) << "begin|size|input size must be equal";
|
MS_LOG(EXCEPTION) << "begin|size|input size must be equal";
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < begin_.size(); i++) {
|
ClipBegin();
|
||||||
while (begin_[i] < 0) {
|
|
||||||
begin_[i] = begin_[i] + input_shape_[i];
|
|
||||||
}
|
|
||||||
if (begin_[i] > SizeToInt(input_shape_[i])) {
|
|
||||||
begin_[i] = input_shape_[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (size_t i = 0; i < sizes.size(); ++i) {
|
for (size_t i = 0; i < sizes.size(); ++i) {
|
||||||
while (sizes[i] < 0) {
|
while (sizes[i] < 0) {
|
||||||
sizes[i] = sizes[i] + input_shape_[i];
|
sizes[i] = sizes[i] + SizeToInt(input_shape_[i]);
|
||||||
}
|
}
|
||||||
strides_.emplace_back(1);
|
strides_.emplace_back(1);
|
||||||
end_.emplace_back(begin_[i] + sizes[i]);
|
end_.emplace_back(begin_[i] + sizes[i]);
|
||||||
|
@ -62,7 +48,17 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||||
CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_);
|
CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_);
|
||||||
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
|
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
|
||||||
}
|
}
|
||||||
|
void SliceCPUKernel::ClipBegin() {
|
||||||
|
for (size_t i = 0; i < begin_.size(); i++) {
|
||||||
|
if (begin_[i] < 0) {
|
||||||
|
auto k = begin_[i] + SizeToInt(input_shape_[i]);
|
||||||
|
begin_[i] = k < 0 ? 0 : k;
|
||||||
|
}
|
||||||
|
if (begin_[i] > SizeToInt(input_shape_[i])) {
|
||||||
|
begin_[i] = SizeToInt(input_shape_[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
void SliceCPUKernel::ExpandAllMemberDims() {
|
void SliceCPUKernel::ExpandAllMemberDims() {
|
||||||
auto input_len = input_shape_.size();
|
auto input_len = input_shape_.size();
|
||||||
if (input_len < 4) {
|
if (input_len < 4) {
|
||||||
|
@ -178,13 +174,13 @@ void SliceCPUKernel::TransArg() {
|
||||||
MS_LOG(EXCEPTION) << "slice stride cannot be zero";
|
MS_LOG(EXCEPTION) << "slice stride cannot be zero";
|
||||||
}
|
}
|
||||||
if (end_[i] == 0 && begin_[i] < 0) {
|
if (end_[i] == 0 && begin_[i] < 0) {
|
||||||
end_[i] = end_[i] + input_shape_[i];
|
end_[i] = end_[i] + SizeToInt(input_shape_[i]);
|
||||||
}
|
}
|
||||||
while (end_[i] < 0) {
|
if (end_[i] < 0) {
|
||||||
end_[i] = end_[i] + input_shape_[i];
|
end_[i] = end_[i] + SizeToInt(input_shape_[i]) < 0 ? 0 : end_[i] + SizeToInt(input_shape_[i]);
|
||||||
}
|
}
|
||||||
if (end_[i] > SizeToInt(input_shape_[i])) {
|
if (end_[i] > SizeToInt(input_shape_[i])) {
|
||||||
end_[i] = input_shape_[i];
|
end_[i] = SizeToInt(input_shape_[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,6 +41,7 @@ class SliceCPUKernel : public CPUKernel {
|
||||||
int id) const;
|
int id) const;
|
||||||
void CheckParam(const CNodePtr &kernel_node) const;
|
void CheckParam(const CNodePtr &kernel_node) const;
|
||||||
void TransArg();
|
void TransArg();
|
||||||
|
void ClipBegin();
|
||||||
std::vector<int> begin_;
|
std::vector<int> begin_;
|
||||||
std::vector<int> end_;
|
std::vector<int> end_;
|
||||||
std::vector<int> strides_;
|
std::vector<int> strides_;
|
||||||
|
|
|
@ -138,9 +138,32 @@ def test_slice5():
|
||||||
assert (output.asnumpy() == inputx[0:3:1, 1:5:1, 0:4:1]).all()
|
assert (output.asnumpy() == inputx[0:3:1, 1:5:1, 0:4:1]).all()
|
||||||
|
|
||||||
|
|
||||||
|
class Slice6(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Slice6, self).__init__()
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
return (x[-10:], x[-5:10:2, :, :], x[-10:10:1, :, -10:10:1])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_slice6():
|
||||||
|
inputx = np.random.rand(4, 4, 4).astype(np.float32)
|
||||||
|
x = Tensor(inputx)
|
||||||
|
slice_op = Slice6()
|
||||||
|
output = slice_op(x)
|
||||||
|
assert (output[0].asnumpy() == inputx[-10:]).all()
|
||||||
|
assert (output[1].asnumpy() == inputx[-5:10:2, :, :]).all()
|
||||||
|
assert (output[2].asnumpy() == inputx[-10:10:1, :, -10:10:1]).all()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_slice()
|
test_slice()
|
||||||
test_slice2()
|
test_slice2()
|
||||||
test_slice3()
|
test_slice3()
|
||||||
test_slice4()
|
test_slice4()
|
||||||
test_slice5()
|
test_slice5()
|
||||||
|
test_slice6()
|
||||||
|
|
Loading…
Reference in New Issue