forked from mindspore-Ecosystem/mindspore
support bool data type for StridedSlice op
This commit is contained in:
parent 6873bcf5eb
commit 42aa96a1a2
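
For context, a minimal usage sketch of what this commit enables: running StridedSlice on a bool tensor with the CPU backend. It is condensed from the tests added below; the Net wrapper name is illustrative and not part of the change. Before this commit the CPU backend had no bool kernel registered for StridedSlice.

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target='CPU')


class Net(nn.Cell):
    """Wraps P.StridedSlice so it can run in graph mode."""
    def __init__(self, begin, end, stride):
        super(Net, self).__init__()
        self.begin = begin
        self.end = end
        self.stride = stride
        self.stride_slice = P.StridedSlice()

    def construct(self, x):
        return self.stride_slice(x, self.begin, self.end, self.stride)


x = Tensor(np.array([[[False, False, True], [False, True, False]],
                     [[False, True, False], [True, False, False]]]), mstype.bool_)
out = Net((1, 0, 0), (2, 1, 3), (1, 1, 1))(x)
print(out.asnumpy())  # expected: [[[False True False]]]
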
@@ -23,6 +23,7 @@ constexpr int MAX_DIMS = 8;
 void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   CheckParam(kernel_node);
   input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+  dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
   std::vector<int64_t> begin_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, BEGIN);
   (void)std::transform(begin_me.begin(), begin_me.end(), std::back_inserter(begin_),
                        [](const int64_t &value) { return static_cast<int>(value); });
@@ -94,8 +95,25 @@ void SliceCPUKernel::ExpandAllMemberDims() {
 bool SliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                             const std::vector<kernel::AddressPtr> & /*workspace*/,
                             const std::vector<kernel::AddressPtr> &outputs) {
-  auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
-  auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
+  bool ret{true};
+  if (dtype_ == kNumberTypeInt32) {
+    ret = LaunchKernel<int>(inputs, outputs);
+  } else if (dtype_ == kNumberTypeFloat32) {
+    ret = LaunchKernel<float>(inputs, outputs);
+  } else if (dtype_ == kNumberTypeBool) {
+    ret = LaunchKernel<bool>(inputs, outputs);
+  } else {
+    MS_LOG(ERROR) << "Slice op only supports bool, int32 and float32 input";
+    return false;
+  }
+  return ret;
+}
+
+template <typename T>
+bool SliceCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
+                                  const std::vector<kernel::AddressPtr> &outputs) {
+  T *input_addr = reinterpret_cast<T *>(inputs[0]->addr);
+  T *output_addr = reinterpret_cast<T *>(outputs[0]->addr);
   bool can_copy_memory[3] = {CanCopyMemoryOnAxis(0), CanCopyMemoryOnAxis(1), CanCopyMemoryOnAxis(2)};
   int signstride[4] = {SignOfStride(0), SignOfStride(1), SignOfStride(2), SignOfStride(3)};
   size_t in_start_offset[3] = {begin_[0] * input_element_num_[0], begin_[1] * input_element_num_[1],
@@ -108,7 +126,7 @@ bool SliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
   for (int i = begin_[0]; signstride[0] * i < signstride[0] * end_[0];
        i += strides_[0], in_n_offset += in_step_size[0], out_n_offset += output_element_num_[0]) {
     if (can_copy_memory[0]) {
-      CopyDataToOutput(inputs, in_n_offset, outputs, out_n_offset, input_element_num_[0], 0);
+      CopyDataToOutput<T>(inputs, in_n_offset, outputs, out_n_offset, input_element_num_[0], 0);
       continue;
     }
     auto in_c_offset = in_start_offset[1];
@@ -116,8 +134,8 @@ bool SliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
     for (int j = begin_[1]; signstride[1] * j < signstride[1] * end_[1];
          j += strides_[1], in_c_offset += in_step_size[1], out_c_offset += output_element_num_[1]) {
       if (can_copy_memory[1]) {
-        CopyDataToOutput(inputs, in_n_offset + in_c_offset, outputs, out_n_offset + out_c_offset, input_element_num_[1],
-                         1);
+        CopyDataToOutput<T>(inputs, in_n_offset + in_c_offset, outputs, out_n_offset + out_c_offset,
+                            input_element_num_[1], 1);
         continue;
       }
       auto in_h_offset = in_start_offset[2];
@@ -125,8 +143,8 @@ bool SliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
       for (int k = begin_[2]; signstride[2] * k < signstride[2] * end_[2];
            k += strides_[2], in_h_offset += in_step_size[2], out_h_offset += output_element_num_[2]) {
         if (can_copy_memory[2]) {
-          CopyDataToOutput(inputs, in_n_offset + in_c_offset + in_h_offset, outputs,
-                           out_n_offset + out_c_offset + out_h_offset, input_element_num_[2], 2);
+          CopyDataToOutput<T>(inputs, in_n_offset + in_c_offset + in_h_offset, outputs,
+                              out_n_offset + out_c_offset + out_h_offset, input_element_num_[2], 2);
           continue;
         }
         for (int m = begin_[3]; signstride[3] * m < signstride[3] * end_[3]; m += strides_[3]) {
@@ -154,23 +172,25 @@ int SliceCPUKernel::SignOfStride(size_t axis) const {
   }
   return -1;
 }
 
+template <typename T>
 void SliceCPUKernel::CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t in_offset,
                                       const std::vector<kernel::AddressPtr> &outputs, size_t out_offset,
                                       size_t copy_num, int id) const {
-  auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
+  T *input_addr = reinterpret_cast<T *>(inputs[0]->addr);
   auto in_buff_size = inputs[0]->size;
-  auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
+  T *output_addr = reinterpret_cast<T *>(outputs[0]->addr);
   auto out_buff_size = outputs[0]->size;
 
-  if ((in_offset + copy_num) * sizeof(float) > in_buff_size) {
+  if ((in_offset + copy_num) * sizeof(T) > in_buff_size) {
     MS_LOG(EXCEPTION) << "input memory out of bounds.";
   }
-  if ((out_offset + copy_num) * sizeof(float) > out_buff_size) {
+  if ((out_offset + copy_num) * sizeof(T) > out_buff_size) {
     MS_LOG(EXCEPTION) << id << " output memory out of bounds.";
   }
 
-  auto ret = memcpy_s(output_addr + out_offset, out_buff_size - out_offset * sizeof(float), input_addr + in_offset,
-                      copy_num * sizeof(float));
+  auto ret = memcpy_s(output_addr + out_offset, out_buff_size - out_offset * sizeof(T), input_addr + in_offset,
+                      copy_num * sizeof(T));
   if (ret != EOK) {
     MS_LOG(EXCEPTION) << "memcpy failed. ret:" << ret;
   }
@@ -33,12 +33,15 @@ class SliceCPUKernel : public CPUKernel {
               const std::vector<AddressPtr> &outputs) override;
 
  private:
-  void ExpandAllMemberDims();
-  bool CanCopyMemoryOnAxis(size_t dim) const;
-  int SignOfStride(size_t axis) const;
+  template <typename T>
+  bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
+  template <typename T>
   void CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t in_offset,
                         const std::vector<kernel::AddressPtr> &outputs, size_t out_offset, size_t copy_num,
                         int id) const;
+  void ExpandAllMemberDims();
+  bool CanCopyMemoryOnAxis(size_t dim) const;
+  int SignOfStride(size_t axis) const;
   void CheckParam(const CNodePtr &kernel_node) const;
   void TransArg();
   void ClipBegin();
@@ -49,6 +52,7 @@ class SliceCPUKernel : public CPUKernel {
   std::vector<size_t> input_element_num_;
   std::vector<size_t> output_shape_;
   std::vector<size_t> output_element_num_;
+  TypeId dtype_{kTypeUnknown};
 };
 
 MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
@@ -58,6 +62,8 @@ MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32).Ad
                   SliceCPUKernel);
 MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                   SliceCPUKernel);
+MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
+                  SliceCPUKernel);
 } // namespace kernel
 } // namespace mindspore
@@ -23,6 +23,7 @@ namespace kernel {
 void SliceGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   CheckParam(kernel_node);
   output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+  dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
   std::vector<int64_t> begin_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, BEGIN);
   (void)std::transform(begin_me.begin(), begin_me.end(), std::back_inserter(begin_),
                        [](const int64_t &value) { return static_cast<int>(value); });
@@ -78,8 +79,25 @@ void SliceGradCPUKernel::ExpandAllMemberDims() {
 bool SliceGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                 const std::vector<kernel::AddressPtr> & /*workspace*/,
                                 const std::vector<kernel::AddressPtr> &outputs) {
-  auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
-  auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
+  bool ret{true};
+  if (dtype_ == kNumberTypeInt32) {
+    ret = LaunchKernel<int>(inputs, outputs);
+  } else if (dtype_ == kNumberTypeFloat32) {
+    ret = LaunchKernel<float>(inputs, outputs);
+  } else if (dtype_ == kNumberTypeBool) {
+    ret = LaunchKernel<bool>(inputs, outputs);
+  } else {
+    MS_LOG(ERROR) << "SliceGrad op only supports bool, int32 and float32 input";
+    return false;
+  }
+  return ret;
+}
+
+template <typename T>
+bool SliceGradCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
+                                      const std::vector<kernel::AddressPtr> &outputs) {
+  T *input_addr = reinterpret_cast<T *>(inputs[0]->addr);
+  T *output_addr = reinterpret_cast<T *>(outputs[0]->addr);
 
   auto ret = memset_s(output_addr, outputs[0]->size, 0, outputs[0]->size);
   if (ret != EOK) {
@@ -97,7 +115,7 @@ bool SliceGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
   for (int i = begin_[0]; stride_signs[0] * i < stride_signs[0] * end_[0];
        i += strides_[0], in_n_offset += input_element_num_[0], out_n_offset += out_step_size[0]) {
     if (can_copy_memory[0]) {
-      CopyDataToOutput(inputs, in_n_offset, outputs, out_n_offset, input_element_num_[0], 0);
+      CopyDataToOutput<T>(inputs, in_n_offset, outputs, out_n_offset, input_element_num_[0], 0);
       continue;
     }
     auto in_c_offset = 0;
@@ -105,8 +123,8 @@ bool SliceGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
     for (int j = begin_[1]; stride_signs[1] * j < stride_signs[1] * end_[1];
          j += strides_[1], in_c_offset += input_element_num_[1], out_c_offset += out_step_size[1]) {
       if (can_copy_memory[1]) {
-        CopyDataToOutput(inputs, in_n_offset + in_c_offset, outputs, out_n_offset + out_c_offset, input_element_num_[1],
-                         1);
+        CopyDataToOutput<T>(inputs, in_n_offset + in_c_offset, outputs, out_n_offset + out_c_offset,
+                            input_element_num_[1], 1);
         continue;
       }
       auto in_h_offset = 0;
@@ -114,8 +132,8 @@ bool SliceGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
       for (int k = begin_[2]; stride_signs[2] * k < stride_signs[2] * end_[2];
            k += strides_[2], in_h_offset += input_element_num_[2], out_h_offset += out_step_size[2]) {
         if (can_copy_memory[2]) {
-          CopyDataToOutput(inputs, in_n_offset + in_c_offset + in_h_offset, outputs,
-                           out_n_offset + out_c_offset + out_h_offset, input_element_num_[2], 2);
+          CopyDataToOutput<T>(inputs, in_n_offset + in_c_offset + in_h_offset, outputs,
+                              out_n_offset + out_c_offset + out_h_offset, input_element_num_[2], 2);
          continue;
         }
         for (int m = begin_[3]; stride_signs[3] * m < stride_signs[3] * end_[3]; m += strides_[3]) {
@@ -143,23 +161,24 @@ int SliceGradCPUKernel::SignOfStride(size_t axis) const {
   return -1;
 }
 
+template <typename T>
 void SliceGradCPUKernel::CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t in_offset,
                                           const std::vector<kernel::AddressPtr> &outputs, size_t out_offset,
                                           size_t copy_num, int id) const {
-  auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
+  T *input_addr = reinterpret_cast<T *>(inputs[0]->addr);
   auto in_buff_size = inputs[0]->size;
-  auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
+  T *output_addr = reinterpret_cast<T *>(outputs[0]->addr);
   auto out_buff_size = outputs[0]->size;
 
-  if ((in_offset + copy_num) * sizeof(float) > in_buff_size) {
+  if ((in_offset + copy_num) * sizeof(T) > in_buff_size) {
     MS_LOG(EXCEPTION) << id << "input memory out of bounds.";
   }
-  if ((out_offset + copy_num) * sizeof(float) > out_buff_size) {
+  if ((out_offset + copy_num) * sizeof(T) > out_buff_size) {
     MS_LOG(EXCEPTION) << id << "output memory out of bounds.";
   }
 
-  auto ret = memcpy_s(output_addr + out_offset, out_buff_size - out_offset * sizeof(float), input_addr + in_offset,
-                      copy_num * sizeof(float));
+  auto ret = memcpy_s(output_addr + out_offset, out_buff_size - out_offset * sizeof(T), input_addr + in_offset,
+                      copy_num * sizeof(T));
   if (ret != EOK) {
     MS_LOG(EXCEPTION) << "memcpy failed. ret:" << ret;
   }
@@ -33,12 +33,16 @@ class SliceGradCPUKernel : public CPUKernel {
               const std::vector<AddressPtr> &outputs) override;
 
  private:
-  void ExpandAllMemberDims();
-  bool CanCopyMemoryOnAxis(size_t dim) const;
-  int SignOfStride(size_t axis) const;
+  template <typename T>
+  bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
+  template <typename T>
   void CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t in_offset,
                         const std::vector<kernel::AddressPtr> &outputs, size_t out_offset, size_t copy_num,
                         int id) const;
+  void ExpandAllMemberDims();
+  bool CanCopyMemoryOnAxis(size_t dim) const;
+  int SignOfStride(size_t axis) const;
+
   void CheckParam(const CNodePtr &kernel_node) const;
   void FormatArgs(bool stride);
   std::vector<int> begin_;
@@ -49,6 +53,7 @@ class SliceGradCPUKernel : public CPUKernel {
   std::vector<size_t> input_element_num_;
   std::vector<size_t> output_shape_;
   std::vector<size_t> output_element_num_;
+  TypeId dtype_{kTypeUnknown};
 };
 
 MS_REG_CPU_KERNEL(
@@ -57,6 +62,10 @@ MS_REG_CPU_KERNEL(
   SliceGradCPUKernel);
 MS_REG_CPU_KERNEL(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   SliceGradCPUKernel);
+MS_REG_CPU_KERNEL(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                  SliceGradCPUKernel);
+MS_REG_CPU_KERNEL(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
+                  SliceGradCPUKernel);
 } // namespace kernel
 } // namespace mindspore
@@ -22,6 +22,7 @@ from mindspore import Tensor
 from mindspore.common import dtype as mstype
 from mindspore.common.api import ms_function
+from mindspore.ops.operations import _grad_ops as G
 from mindspore.ops import operations as P
 
 context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
@@ -77,7 +78,52 @@ def test_slice_grad2():
               [[0., 0.], [8., 9.], [10., 11.]]]
     assert (output.asnumpy() == expect).all()
 
 
+class StridedSliceGrad(nn.Cell):
+    def __init__(self, x, begin, end, stride):
+        super(StridedSliceGrad, self).__init__()
+        self.shape_op = P.Shape()
+        self.shapex = self.shape_op(x)
+        self.begin = begin
+        self.end = end
+        self.stride = stride
+        self.stride_slice = G.StridedSliceGrad()
+
+    def construct(self, dy):
+        return self.stride_slice(dy, self.shapex, self.begin, self.end, self.stride)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_strided_slice_grad_bool_type():
+    x = Tensor([[[False, False, True], [False, True, False]], [[False, True, False], [True, False, False]],
+                [[False, True, True], [True, False, True]]], mstype.bool_)
+    dy = Tensor([False, True, False], mstype.bool_)
+    begin = (1, 0, 0)
+    end = (2, 1, 3)
+    stride = (1, 1, 1)
+    slice_op = StridedSliceGrad(x, begin, end, stride)
+    output = slice_op(dy)
+    expected_output = np.array([[[False, False, False], [False, False, False]],
+                                [[False, True, False], [False, False, False]],
+                                [[False, False, False], [False, False, False]]])
+    assert (output.asnumpy() == expected_output).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_strided_slice_grad_float32_type():
+    x = Tensor([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]]], mstype.float32)
+    dy = Tensor([3, 3, 3], mstype.float32)
+    begin = (1, 0, 0)
+    end = (2, 1, 3)
+    stride = (1, 1, 1)
+    slice_op = StridedSliceGrad(x, begin, end, stride)
+    output = slice_op(dy)
+    expected_output = np.array([[[0, 0, 0], [0, 0, 0]], [[3, 3, 3], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]])
+    assert (output.asnumpy() == expected_output).all()
+
+
 if __name__ == '__main__':
     test_slice_grad()
     test_slice_grad2()
+    test_strided_slice_grad_bool_type()
+    test_strided_slice_grad_float32_type()
@@ -160,6 +160,31 @@ def test_slice6():
     assert (output[2].asnumpy() == inputx[-10:10:1, :, -10:10:1]).all()
 
 
+class StridedSlice(nn.Cell):
+    def __init__(self, begin, end, stride):
+        super(StridedSlice, self).__init__()
+        self.begin = begin
+        self.end = end
+        self.stride = stride
+        self.stride_slice = P.StridedSlice()
+
+    def construct(self, x):
+        return self.stride_slice(x, self.begin, self.end, self.stride)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_strided_slice_bool_type():
+    input_x = Tensor([[[False, False, True], [False, True, False]], [[False, True, False], [True, False, False]],
+                      [[False, True, True], [True, False, True]]], mstype.bool_)
+    begin = (1, 0, 0)
+    end = (2, 1, 3)
+    stride = (1, 1, 1)
+    slice_op = StridedSlice(begin, end, stride)
+    output = slice_op(input_x)
+    expected_output = np.array([False, True, False])
+    assert (output.asnumpy() == expected_output).all()
+
 if __name__ == '__main__':
     test_slice()
     test_slice2()
@@ -167,3 +192,4 @@ if __name__ == '__main__':
     test_slice4()
     test_slice5()
     test_slice6()
+    test_strided_slice_bool_type()