forked from mindspore-Ecosystem/mindspore
!7123 fix cpu slice op
Merge pull request !7123 from baihuawei/fixslice
This commit is contained in:
commit
f5e91a544d
|
@ -21,7 +21,6 @@ namespace kernel {
|
|||
void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
CheckParam(kernel_node);
|
||||
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
begin_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, BEGIN);
|
||||
auto prim = AnfAlgo::GetCNodePrimitive(kernel_node);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
|
@ -65,12 +64,6 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
}
|
||||
|
||||
void SliceCPUKernel::ExpandAllMemberDims() {
|
||||
auto output_len = output_shape_.size();
|
||||
if (output_len < 4) {
|
||||
for (size_t i = 0; i < 4 - output_len; ++i) {
|
||||
output_shape_.push_back(1);
|
||||
}
|
||||
}
|
||||
auto input_len = input_shape_.size();
|
||||
if (input_len < 4) {
|
||||
for (size_t i = 0; i < 4 - input_len; ++i) {
|
||||
|
@ -80,6 +73,15 @@ void SliceCPUKernel::ExpandAllMemberDims() {
|
|||
end_.insert(end_.begin(), 1);
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < 4; ++i) {
|
||||
if (SignOfStride(i)) {
|
||||
int ax = (end_[i] - begin_[i]) * SignOfStride(i);
|
||||
if (ax < 0) {
|
||||
ax = 0;
|
||||
}
|
||||
output_shape_.push_back(IntToSize(ax));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool SliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
|
@ -87,7 +89,6 @@ bool SliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
|
||||
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
|
||||
|
||||
bool can_copy_memory[3] = {CanCopyMemoryOnAxis(0), CanCopyMemoryOnAxis(1), CanCopyMemoryOnAxis(2)};
|
||||
int signstride[4] = {SignOfStride(0), SignOfStride(1), SignOfStride(2), SignOfStride(3)};
|
||||
size_t in_start_offset[3] = {begin_[0] * input_element_num_[0], begin_[1] * input_element_num_[1],
|
||||
|
|
|
@ -45,7 +45,6 @@ def test_slice():
|
|||
|
||||
slice_op = Slice()
|
||||
output = slice_op(x)
|
||||
print("output:\n", output)
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
|
||||
|
@ -68,7 +67,6 @@ def test_slice2():
|
|||
|
||||
slice_op = Slice2()
|
||||
output = slice_op(x)
|
||||
print("output:\n", output)
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
|
||||
|
@ -115,8 +113,34 @@ def test_slice4():
|
|||
assert (output.asnumpy() == inputx[:10:1, :, 2:3:1]).all()
|
||||
|
||||
|
||||
class Slice5(nn.Cell):
|
||||
def __init__(self, begin, size):
|
||||
super(Slice5, self).__init__()
|
||||
self.relu = nn.ReLU()
|
||||
self.slice = P.Slice()
|
||||
self.begin = begin
|
||||
self.size = size
|
||||
|
||||
def construct(self, x):
|
||||
return self.slice(x, self.begin, self.size)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_slice5():
|
||||
inputx = np.arange(3 * 5 * 4).reshape(3, 5, 4).astype(np.float32)
|
||||
x = Tensor(inputx)
|
||||
begin = (0, 1, 0)
|
||||
size = (3, 4, 4)
|
||||
slice_op = Slice5(begin, size)
|
||||
output = slice_op(x)
|
||||
assert (output.asnumpy() == inputx[0:3:1, 1:5:1, 0:4:1]).all()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_slice()
|
||||
test_slice2()
|
||||
test_slice3()
|
||||
test_slice4()
|
||||
test_slice5()
|
||||
|
|
Loading…
Reference in New Issue