forked from mindspore-Ecosystem/mindspore
!6912 generalize CPU Slice op
Merge pull request !6912 from baihuawei/fixslice
This commit is contained in:
commit
d8cfba8ba6
|
@ -22,28 +22,20 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||||
CheckParam(kernel_node);
|
CheckParam(kernel_node);
|
||||||
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||||
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||||
|
|
||||||
begin_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, BEGIN);
|
begin_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, BEGIN);
|
||||||
for (size_t i = 0; i < begin_.size(); i++) {
|
|
||||||
if (begin_[i] < 0) {
|
|
||||||
begin_[i] = begin_[i] + input_shape_[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
auto prim = AnfAlgo::GetCNodePrimitive(kernel_node);
|
auto prim = AnfAlgo::GetCNodePrimitive(kernel_node);
|
||||||
MS_EXCEPTION_IF_NULL(prim);
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
auto strides = prim->GetAttr(STRIDES);
|
auto strides = prim->GetAttr(STRIDES);
|
||||||
if (strides != nullptr) {
|
if (strides != nullptr) {
|
||||||
strides_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, STRIDES);
|
strides_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, STRIDES);
|
||||||
end_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, END);
|
end_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, END);
|
||||||
if (strides_.size() != end_.size() || strides_.size() != input_shape_.size()) {
|
TransArg();
|
||||||
MS_LOG(EXCEPTION) << "stride|end|input size must be equal";
|
for (size_t i = 0; i < begin_.size(); i++) {
|
||||||
|
while (begin_[i] < 0) {
|
||||||
|
begin_[i] = begin_[i] + input_shape_[i];
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < strides_.size(); ++i) {
|
if (begin_[i] > SizeToInt(input_shape_[i])) {
|
||||||
if (strides_[i] < 0) {
|
begin_[i] = input_shape_[i];
|
||||||
strides_[i] = (strides_[i] + input_shape_[i]) > 0 ? (strides_[i] + input_shape_[i]) : 0;
|
|
||||||
}
|
|
||||||
if (end_[i] < 0) {
|
|
||||||
end_[i] = (end_[i] + input_shape_[i]) > 0 ? (end_[i] + input_shape_[i]) : 0;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -51,23 +43,34 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||||
if (sizes.size() != input_shape_.size() || begin_.size() != input_shape_.size()) {
|
if (sizes.size() != input_shape_.size() || begin_.size() != input_shape_.size()) {
|
||||||
MS_LOG(EXCEPTION) << "begin|size|input size must be equal";
|
MS_LOG(EXCEPTION) << "begin|size|input size must be equal";
|
||||||
}
|
}
|
||||||
|
for (size_t i = 0; i < begin_.size(); i++) {
|
||||||
|
while (begin_[i] < 0) {
|
||||||
|
begin_[i] = begin_[i] + input_shape_[i];
|
||||||
|
}
|
||||||
|
if (begin_[i] > SizeToInt(input_shape_[i])) {
|
||||||
|
begin_[i] = input_shape_[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
for (size_t i = 0; i < sizes.size(); ++i) {
|
for (size_t i = 0; i < sizes.size(); ++i) {
|
||||||
if (sizes[i] < 0) {
|
while (sizes[i] < 0) {
|
||||||
sizes[i] = (sizes[i] + input_shape_[i]) > 0 ? (sizes[i] + input_shape_[i]) : 0;
|
sizes[i] = sizes[i] + input_shape_[i];
|
||||||
}
|
}
|
||||||
strides_.emplace_back(1);
|
strides_.emplace_back(1);
|
||||||
end_.emplace_back(begin_[i] + sizes[i]);
|
end_.emplace_back(begin_[i] + sizes[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ExpandAllMemberDims();
|
ExpandAllMemberDims();
|
||||||
CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_);
|
CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_);
|
||||||
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
|
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SliceCPUKernel::ExpandAllMemberDims() {
|
void SliceCPUKernel::ExpandAllMemberDims() {
|
||||||
CPUKernelUtils::ExpandDimsTo4(&output_shape_);
|
auto output_len = output_shape_.size();
|
||||||
|
if (output_len < 4) {
|
||||||
|
for (size_t i = 0; i < 4 - output_len; ++i) {
|
||||||
|
output_shape_.push_back(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
auto input_len = input_shape_.size();
|
auto input_len = input_shape_.size();
|
||||||
if (input_len < 4) {
|
if (input_len < 4) {
|
||||||
for (size_t i = 0; i < 4 - input_len; ++i) {
|
for (size_t i = 0; i < 4 - input_len; ++i) {
|
||||||
|
@ -86,6 +89,7 @@ bool SliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||||
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
|
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
|
||||||
|
|
||||||
bool can_copy_memory[3] = {CanCopyMemoryOnAxis(0), CanCopyMemoryOnAxis(1), CanCopyMemoryOnAxis(2)};
|
bool can_copy_memory[3] = {CanCopyMemoryOnAxis(0), CanCopyMemoryOnAxis(1), CanCopyMemoryOnAxis(2)};
|
||||||
|
int signstride[4] = {SignOfStride(0), SignOfStride(1), SignOfStride(2), SignOfStride(3)};
|
||||||
size_t in_start_offset[3] = {begin_[0] * input_element_num_[0], begin_[1] * input_element_num_[1],
|
size_t in_start_offset[3] = {begin_[0] * input_element_num_[0], begin_[1] * input_element_num_[1],
|
||||||
begin_[2] * input_element_num_[2]};
|
begin_[2] * input_element_num_[2]};
|
||||||
size_t in_step_size[3] = {strides_[0] * input_element_num_[0], strides_[1] * input_element_num_[1],
|
size_t in_step_size[3] = {strides_[0] * input_element_num_[0], strides_[1] * input_element_num_[1],
|
||||||
|
@ -93,31 +97,31 @@ bool SliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||||
|
|
||||||
auto in_n_offset = in_start_offset[0];
|
auto in_n_offset = in_start_offset[0];
|
||||||
auto out_n_offset = 0;
|
auto out_n_offset = 0;
|
||||||
for (int i = begin_[0]; i < end_[0];
|
for (int i = begin_[0]; signstride[0] * i < signstride[0] * end_[0];
|
||||||
i += strides_[0], in_n_offset += in_step_size[0], out_n_offset += output_element_num_[0]) {
|
i += strides_[0], in_n_offset += in_step_size[0], out_n_offset += output_element_num_[0]) {
|
||||||
if (can_copy_memory[0]) {
|
if (can_copy_memory[0]) {
|
||||||
CopyDataToOutput(inputs, in_n_offset, outputs, out_n_offset, input_element_num_[0]);
|
CopyDataToOutput(inputs, in_n_offset, outputs, out_n_offset, input_element_num_[0], 0);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto in_c_offset = in_start_offset[1];
|
auto in_c_offset = in_start_offset[1];
|
||||||
auto out_c_offset = 0;
|
auto out_c_offset = 0;
|
||||||
for (int j = begin_[1]; j < end_[1];
|
for (int j = begin_[1]; signstride[1] * j < signstride[1] * end_[1];
|
||||||
j += strides_[1], in_c_offset += in_step_size[1], out_c_offset += output_element_num_[1]) {
|
j += strides_[1], in_c_offset += in_step_size[1], out_c_offset += output_element_num_[1]) {
|
||||||
if (can_copy_memory[1]) {
|
if (can_copy_memory[1]) {
|
||||||
CopyDataToOutput(inputs, in_n_offset + in_c_offset, outputs, out_n_offset + out_c_offset,
|
CopyDataToOutput(inputs, in_n_offset + in_c_offset, outputs, out_n_offset + out_c_offset, input_element_num_[1],
|
||||||
input_element_num_[1]);
|
1);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto in_h_offset = in_start_offset[2];
|
auto in_h_offset = in_start_offset[2];
|
||||||
auto out_h_offset = 0;
|
auto out_h_offset = 0;
|
||||||
for (int k = begin_[2]; k < end_[2];
|
for (int k = begin_[2]; signstride[2] * k < signstride[2] * end_[2];
|
||||||
k += strides_[2], in_h_offset += in_step_size[2], out_h_offset += output_element_num_[2]) {
|
k += strides_[2], in_h_offset += in_step_size[2], out_h_offset += output_element_num_[2]) {
|
||||||
if (can_copy_memory[2]) {
|
if (can_copy_memory[2]) {
|
||||||
CopyDataToOutput(inputs, in_n_offset + in_c_offset + in_h_offset, outputs,
|
CopyDataToOutput(inputs, in_n_offset + in_c_offset + in_h_offset, outputs,
|
||||||
out_n_offset + out_c_offset + out_h_offset, input_element_num_[2]);
|
out_n_offset + out_c_offset + out_h_offset, input_element_num_[2], 2);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
for (int m = begin_[3]; m < end_[3]; m += strides_[3]) {
|
for (int m = begin_[3]; signstride[3] * m < signstride[3] * end_[3]; m += strides_[3]) {
|
||||||
*output_addr++ = input_addr[in_n_offset + in_c_offset + in_h_offset + m];
|
*output_addr++ = input_addr[in_n_offset + in_c_offset + in_h_offset + m];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -136,9 +140,15 @@ bool SliceCPUKernel::CanCopyMemoryOnAxis(size_t dim) const {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int SliceCPUKernel::SignOfStride(size_t axis) const {
|
||||||
|
if (strides_[axis] > 0) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
void SliceCPUKernel::CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t in_offset,
|
void SliceCPUKernel::CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t in_offset,
|
||||||
const std::vector<kernel::AddressPtr> &outputs, size_t out_offset,
|
const std::vector<kernel::AddressPtr> &outputs, size_t out_offset,
|
||||||
size_t copy_num) const {
|
size_t copy_num, int id) const {
|
||||||
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
|
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
|
||||||
auto in_buff_size = inputs[0]->size;
|
auto in_buff_size = inputs[0]->size;
|
||||||
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
|
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
|
||||||
|
@ -148,7 +158,7 @@ void SliceCPUKernel::CopyDataToOutput(const std::vector<kernel::AddressPtr> &inp
|
||||||
MS_LOG(EXCEPTION) << "input memory out of bounds.";
|
MS_LOG(EXCEPTION) << "input memory out of bounds.";
|
||||||
}
|
}
|
||||||
if ((out_offset + copy_num) * sizeof(float) > out_buff_size) {
|
if ((out_offset + copy_num) * sizeof(float) > out_buff_size) {
|
||||||
MS_LOG(EXCEPTION) << "output memory out of bounds.";
|
MS_LOG(EXCEPTION) << id << " output memory out of bounds.";
|
||||||
}
|
}
|
||||||
|
|
||||||
auto ret = memcpy_s(output_addr + out_offset, out_buff_size - out_offset * sizeof(float), input_addr + in_offset,
|
auto ret = memcpy_s(output_addr + out_offset, out_buff_size - out_offset * sizeof(float), input_addr + in_offset,
|
||||||
|
@ -158,6 +168,26 @@ void SliceCPUKernel::CopyDataToOutput(const std::vector<kernel::AddressPtr> &inp
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SliceCPUKernel::TransArg() {
|
||||||
|
if (strides_.size() != end_.size() || strides_.size() != input_shape_.size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "stride|end|input size must be equal";
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < strides_.size(); ++i) {
|
||||||
|
if (strides_[i] == 0) {
|
||||||
|
MS_LOG(EXCEPTION) << "slice stride cannot be zero";
|
||||||
|
}
|
||||||
|
if (end_[i] == 0 && begin_[i] < 0) {
|
||||||
|
end_[i] = end_[i] + input_shape_[i];
|
||||||
|
}
|
||||||
|
while (end_[i] < 0) {
|
||||||
|
end_[i] = end_[i] + input_shape_[i];
|
||||||
|
}
|
||||||
|
if (end_[i] > SizeToInt(input_shape_[i])) {
|
||||||
|
end_[i] = input_shape_[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void SliceCPUKernel::CheckParam(const CNodePtr &kernel_node) const {
|
void SliceCPUKernel::CheckParam(const CNodePtr &kernel_node) const {
|
||||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||||
if (input_num != 1) {
|
if (input_num != 1) {
|
||||||
|
|
|
@ -35,9 +35,12 @@ class SliceCPUKernel : public CPUKernel {
|
||||||
private:
|
private:
|
||||||
void ExpandAllMemberDims();
|
void ExpandAllMemberDims();
|
||||||
bool CanCopyMemoryOnAxis(size_t dim) const;
|
bool CanCopyMemoryOnAxis(size_t dim) const;
|
||||||
|
int SignOfStride(size_t axis) const;
|
||||||
void CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t in_offset,
|
void CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t in_offset,
|
||||||
const std::vector<kernel::AddressPtr> &outputs, size_t out_offset, size_t copy_num) const;
|
const std::vector<kernel::AddressPtr> &outputs, size_t out_offset, size_t copy_num,
|
||||||
|
int id) const;
|
||||||
void CheckParam(const CNodePtr &kernel_node) const;
|
void CheckParam(const CNodePtr &kernel_node) const;
|
||||||
|
void TransArg();
|
||||||
std::vector<int> begin_;
|
std::vector<int> begin_;
|
||||||
std::vector<int> end_;
|
std::vector<int> end_;
|
||||||
std::vector<int> strides_;
|
std::vector<int> strides_;
|
||||||
|
|
|
@ -72,6 +72,51 @@ def test_slice2():
|
||||||
assert (output.asnumpy() == expect).all()
|
assert (output.asnumpy() == expect).all()
|
||||||
|
|
||||||
|
|
||||||
|
class Slice3(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Slice3, self).__init__()
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
return (x[..., -1], x[..., 2:1:-1], x[1:3:1, 0, ...], x[-1, 0, ...])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_slice3():
|
||||||
|
inputx = np.random.rand(4, 4, 4, 4).astype(np.float32)
|
||||||
|
x = Tensor(inputx)
|
||||||
|
slice_op = Slice3()
|
||||||
|
output = slice_op(x)
|
||||||
|
assert (output[0].asnumpy() == inputx[..., -1]).all()
|
||||||
|
assert (output[1].asnumpy() == inputx[..., 2:1:-1]).all()
|
||||||
|
assert (output[2].asnumpy() == inputx[1:3:1, 0, ...]).all()
|
||||||
|
assert (output[3].asnumpy() == inputx[-1, 0, ...]).all()
|
||||||
|
|
||||||
|
|
||||||
|
class Slice4(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Slice4, self).__init__()
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
return x[:10:1, :, 2:3:1]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_slice4():
|
||||||
|
inputx = np.random.rand(4, 4, 4).astype(np.float32)
|
||||||
|
x = Tensor(inputx)
|
||||||
|
slice_op = Slice4()
|
||||||
|
output = slice_op(x)
|
||||||
|
assert (output.asnumpy() == inputx[:10:1, :, 2:3:1]).all()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_slice()
|
test_slice()
|
||||||
test_slice2()
|
test_slice2()
|
||||||
|
test_slice3()
|
||||||
|
test_slice4()
|
||||||
|
|
Loading…
Reference in New Issue