!7123 fix cpu slice op

Merge pull request !7123 from baihuawei/fixslice
This commit is contained in:
mindspore-ci-bot 2020-10-10 13:51:35 +08:00 committed by Gitee
commit f5e91a544d
2 changed files with 35 additions and 10 deletions

View File

@ -21,7 +21,6 @@ namespace kernel {
void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
CheckParam(kernel_node);
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
begin_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, BEGIN);
auto prim = AnfAlgo::GetCNodePrimitive(kernel_node);
MS_EXCEPTION_IF_NULL(prim);
@ -65,12 +64,6 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
}
void SliceCPUKernel::ExpandAllMemberDims() {
auto output_len = output_shape_.size();
if (output_len < 4) {
for (size_t i = 0; i < 4 - output_len; ++i) {
output_shape_.push_back(1);
}
}
auto input_len = input_shape_.size();
if (input_len < 4) {
for (size_t i = 0; i < 4 - input_len; ++i) {
@ -80,6 +73,15 @@ void SliceCPUKernel::ExpandAllMemberDims() {
end_.insert(end_.begin(), 1);
}
}
for (size_t i = 0; i < 4; ++i) {
if (SignOfStride(i)) {
int ax = (end_[i] - begin_[i]) * SignOfStride(i);
if (ax < 0) {
ax = 0;
}
output_shape_.push_back(IntToSize(ax));
}
}
}
bool SliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
@ -87,7 +89,6 @@ bool SliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
bool can_copy_memory[3] = {CanCopyMemoryOnAxis(0), CanCopyMemoryOnAxis(1), CanCopyMemoryOnAxis(2)};
int signstride[4] = {SignOfStride(0), SignOfStride(1), SignOfStride(2), SignOfStride(3)};
size_t in_start_offset[3] = {begin_[0] * input_element_num_[0], begin_[1] * input_element_num_[1],

View File

@ -45,7 +45,6 @@ def test_slice():
slice_op = Slice()
output = slice_op(x)
print("output:\n", output)
assert (output.asnumpy() == expect).all()
@ -68,7 +67,6 @@ def test_slice2():
slice_op = Slice2()
output = slice_op(x)
print("output:\n", output)
assert (output.asnumpy() == expect).all()
@ -115,8 +113,34 @@ def test_slice4():
assert (output.asnumpy() == inputx[:10:1, :, 2:3:1]).all()
class Slice5(nn.Cell):
def __init__(self, begin, size):
super(Slice5, self).__init__()
self.relu = nn.ReLU()
self.slice = P.Slice()
self.begin = begin
self.size = size
def construct(self, x):
return self.slice(x, self.begin, self.size)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_slice5():
inputx = np.arange(3 * 5 * 4).reshape(3, 5, 4).astype(np.float32)
x = Tensor(inputx)
begin = (0, 1, 0)
size = (3, 4, 4)
slice_op = Slice5(begin, size)
output = slice_op(x)
assert (output.asnumpy() == inputx[0:3:1, 1:5:1, 0:4:1]).all()
if __name__ == '__main__':
test_slice()
test_slice2()
test_slice3()
test_slice4()
test_slice5()