forked from mindspore-Ecosystem/mindspore
fix scpu slice
This commit is contained in:
parent
39bc43e674
commit
e7928b9c0e
|
@ -29,30 +29,16 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
strides_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, STRIDES);
|
||||
end_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, END);
|
||||
TransArg();
|
||||
for (size_t i = 0; i < begin_.size(); i++) {
|
||||
while (begin_[i] < 0) {
|
||||
begin_[i] = begin_[i] + input_shape_[i];
|
||||
}
|
||||
if (begin_[i] > SizeToInt(input_shape_[i])) {
|
||||
begin_[i] = input_shape_[i];
|
||||
}
|
||||
}
|
||||
ClipBegin();
|
||||
} else {
|
||||
auto sizes = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, SIZE);
|
||||
if (sizes.size() != input_shape_.size() || begin_.size() != input_shape_.size()) {
|
||||
MS_LOG(EXCEPTION) << "begin|size|input size must be equal";
|
||||
}
|
||||
for (size_t i = 0; i < begin_.size(); i++) {
|
||||
while (begin_[i] < 0) {
|
||||
begin_[i] = begin_[i] + input_shape_[i];
|
||||
}
|
||||
if (begin_[i] > SizeToInt(input_shape_[i])) {
|
||||
begin_[i] = input_shape_[i];
|
||||
}
|
||||
}
|
||||
ClipBegin();
|
||||
for (size_t i = 0; i < sizes.size(); ++i) {
|
||||
while (sizes[i] < 0) {
|
||||
sizes[i] = sizes[i] + input_shape_[i];
|
||||
sizes[i] = sizes[i] + SizeToInt(input_shape_[i]);
|
||||
}
|
||||
strides_.emplace_back(1);
|
||||
end_.emplace_back(begin_[i] + sizes[i]);
|
||||
|
@ -62,7 +48,17 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_);
|
||||
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
|
||||
}
|
||||
|
||||
void SliceCPUKernel::ClipBegin() {
|
||||
for (size_t i = 0; i < begin_.size(); i++) {
|
||||
if (begin_[i] < 0) {
|
||||
auto k = begin_[i] + SizeToInt(input_shape_[i]);
|
||||
begin_[i] = k < 0 ? 0 : k;
|
||||
}
|
||||
if (begin_[i] > SizeToInt(input_shape_[i])) {
|
||||
begin_[i] = SizeToInt(input_shape_[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
void SliceCPUKernel::ExpandAllMemberDims() {
|
||||
auto input_len = input_shape_.size();
|
||||
if (input_len < 4) {
|
||||
|
@ -178,13 +174,13 @@ void SliceCPUKernel::TransArg() {
|
|||
MS_LOG(EXCEPTION) << "slice stride cannot be zero";
|
||||
}
|
||||
if (end_[i] == 0 && begin_[i] < 0) {
|
||||
end_[i] = end_[i] + input_shape_[i];
|
||||
end_[i] = end_[i] + SizeToInt(input_shape_[i]);
|
||||
}
|
||||
while (end_[i] < 0) {
|
||||
end_[i] = end_[i] + input_shape_[i];
|
||||
if (end_[i] < 0) {
|
||||
end_[i] = end_[i] + SizeToInt(input_shape_[i]) < 0 ? 0 : end_[i] + SizeToInt(input_shape_[i]);
|
||||
}
|
||||
if (end_[i] > SizeToInt(input_shape_[i])) {
|
||||
end_[i] = input_shape_[i];
|
||||
end_[i] = SizeToInt(input_shape_[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -41,6 +41,7 @@ class SliceCPUKernel : public CPUKernel {
|
|||
int id) const;
|
||||
void CheckParam(const CNodePtr &kernel_node) const;
|
||||
void TransArg();
|
||||
void ClipBegin();
|
||||
std::vector<int> begin_;
|
||||
std::vector<int> end_;
|
||||
std::vector<int> strides_;
|
||||
|
|
|
@ -138,9 +138,32 @@ def test_slice5():
|
|||
assert (output.asnumpy() == inputx[0:3:1, 1:5:1, 0:4:1]).all()
|
||||
|
||||
|
||||
class Slice6(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Slice6, self).__init__()
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def construct(self, x):
|
||||
return (x[-10:], x[-5:10:2, :, :], x[-10:10:1, :, -10:10:1])
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_slice6():
|
||||
inputx = np.random.rand(4, 4, 4).astype(np.float32)
|
||||
x = Tensor(inputx)
|
||||
slice_op = Slice6()
|
||||
output = slice_op(x)
|
||||
assert (output[0].asnumpy() == inputx[-10:]).all()
|
||||
assert (output[1].asnumpy() == inputx[-5:10:2, :, :]).all()
|
||||
assert (output[2].asnumpy() == inputx[-10:10:1, :, -10:10:1]).all()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_slice()
|
||||
test_slice2()
|
||||
test_slice3()
|
||||
test_slice4()
|
||||
test_slice5()
|
||||
test_slice6()
|
||||
|
|
Loading…
Reference in New Issue