gpu stridedslice

This commit is contained in:
wilfChen 2020-08-21 16:21:27 +08:00
parent 9b9d7be73d
commit 837aecf9af
4 changed files with 45 additions and 87 deletions

View File

@ -59,6 +59,7 @@ class StridedSliceGpuKernel : public GpuKernel {
ParseMasks(kernel_node); ParseMasks(kernel_node);
FillOutputDim(); FillOutputDim();
null_output_ = IsNullOutput(); null_output_ = IsNullOutput();
InitSizeLists(); InitSizeLists();
return true; return true;
} }
@ -86,14 +87,15 @@ class StridedSliceGpuKernel : public GpuKernel {
for (size_t i = 0; i < MAX_DIMS; i++) { for (size_t i = 0; i < MAX_DIMS; i++) {
if (i < begin_.size()) { if (i < begin_.size()) {
begin_[i] = int dim = SizeToInt(input_shape_[i]);
std::min(begin_[i] < 0 ? SizeToInt(begin_[i] + input_shape_[i]) : begin_[i], SizeToInt(input_shape_[i] - 1)); begin_[i] = std::min(begin_[i] < 0 ? std::max(begin_[i] + dim, 0) : begin_[i], dim - 1);
} else { } else {
begin_.push_back(0); begin_.push_back(0);
} }
if (i < end_.size()) { if (i < end_.size()) {
end_[i] = std::max(end_[i] < 0 ? end_[i] + SizeToInt(input_shape_[i]) : end_[i], -1); int dim = SizeToInt(input_shape_[i]);
end_[i] = std::max(end_[i] < 0 ? end_[i] + dim : std::min(end_[i], dim), -1);
} else { } else {
end_.push_back(i < input_shape_.size() ? input_shape_[i] : 1); end_.push_back(i < input_shape_.size() ? input_shape_[i] : 1);
} }

View File

@ -87,14 +87,15 @@ class StridedSliceGradGpuKernel : public GpuKernel {
for (size_t i = 0; i < MAX_DIMS; i++) { for (size_t i = 0; i < MAX_DIMS; i++) {
if (i < begin_.size()) { if (i < begin_.size()) {
begin_[i] = int dim = SizeToInt(input_shape_[i]);
std::min(begin_[i] < 0 ? SizeToInt(begin_[i] + input_shape_[i]) : begin_[i], SizeToInt(input_shape_[i] - 1)); begin_[i] = std::min(begin_[i] < 0 ? std::max(begin_[i] + dim, 0) : begin_[i], dim - 1);
} else { } else {
begin_.push_back(0); begin_.push_back(0);
} }
if (i < end_.size()) { if (i < end_.size()) {
end_[i] = std::max(end_[i] < 0 ? end_[i] + SizeToInt(input_shape_[i]) : end_[i], -1); int dim = SizeToInt(input_shape_[i]);
end_[i] = std::max(end_[i] < 0 ? end_[i] + dim : std::min(end_[i], dim), -1);
} else { } else {
end_.push_back(i < input_shape_.size() ? input_shape_[i] : 1); end_.push_back(i < input_shape_.size() ? input_shape_[i] : 1);
} }

View File

@ -150,73 +150,6 @@ def strided_slice_grad(nptype):
[0., 0., 0., 0., 0.]]]]).astype(nptype) [0., 0., 0., 0., 0.]]]]).astype(nptype)
assert np.allclose(dx[0].asnumpy(), expect) assert np.allclose(dx[0].asnumpy(), expect)
# ME infer fault
# y = GradData()(x, (1, 0, -1, -2), (2, 2, 0, -5), (1, 1, -1, -2))
# expect = np.array([[[[0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.]],
# [[0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.]],
# [[0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.]]],
# [[[0., 0., 0., 0., 0.],
# [0., 1., 0., 1., 0.],
# [0., 1., 0., 1., 0.],
# [0., 1., 0., 1., 0.]],
# [[0., 0., 0., 0., 0.],
# [0., 1., 0., 1., 0.],
# [0., 1., 0., 1., 0.],
# [0., 1., 0., 1., 0.]],begin_mask=0b1000, end_mask=0b0010, ellipsis_mask=0b0100
# [[0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.]]]])
# assert np.allclose(y.asnumpy(), expect)
# y = Grad(begin_mask=0b1000, end_mask=0b0010)(x, (1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1))
# expect = np.array([[[[0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.]],
# [[0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.]],
# [[0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.]]],
# [[[0., 0., 1., 1., 0.],
# [0., 0., 1., 1., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.]],
# [[0., 0., 1., 1., 0.],
# [0., 0., 1., 1., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.]],
# [[0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.],
# [0., 0., 0., 0., 0.]]]])
# assert np.allclose(y.asnumpy(), expect)
net = StridedSliceNet((1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1), net = StridedSliceNet((1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1),
begin_mask=0b1000, end_mask=0b0010, ellipsis_mask=0b0100) begin_mask=0b1000, end_mask=0b0010, ellipsis_mask=0b0100)

View File

@ -45,23 +45,23 @@ def strided_slice(nptype):
[89, 88, 87]]]]).astype(nptype) [89, 88, 87]]]]).astype(nptype)
assert np.allclose(y.asnumpy(), expect) assert np.allclose(y.asnumpy(), expect)
# ME infer fault y = P.StridedSlice()(x, (1, 0, -1, -2), (2, 2, 0, -5), (1, 1, -1, -2))
# y = P.StridedSlice()(x, (1, 0, -1, -2), (2, 2, 0, -5), (1, 1, -1, -2)) expect = np.array([[[[78, 76],
# expect = np.array([[[[78, 76], [73, 71],
# [73, 71], [68, 66]],
# [68, 66]], [[98, 96],
# [[98, 96], [93, 91],
# [93, 91], [88, 86]]]]).astype(nptype)
# [88, 86]]]]) assert np.allclose(y.asnumpy(), expect)
# assert np.allclose(y.asnumpy(), expect)
# ME Infer fault
# y = P.StridedSlice(begin_mask=0b1000, end_mask=0b0010)(x, (1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1)) # y = P.StridedSlice(begin_mask=0b1000, end_mask=0b0010)(x, (1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1))
# expect = np.array([[[[ 62, 63], # expect = np.array([[[[62, 63],
# [ 67, 68]], # [67, 68]],
# [[ 82, 83], # [[82, 83],
# [ 87, 88]], # [87, 88]],
# [[102, 103], # [[102, 103],
# [107, 108]]]]) # [107, 108]]]]).astype(nptype)
# assert np.allclose(y.asnumpy(), expect) # assert np.allclose(y.asnumpy(), expect)
op = P.StridedSlice(begin_mask=0b1000, end_mask=0b0010, ellipsis_mask=0b0100) op = P.StridedSlice(begin_mask=0b1000, end_mask=0b0010, ellipsis_mask=0b0100)
@ -125,3 +125,25 @@ def test_strided_slice_uint8():
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_strided_slice_bool(): def test_strided_slice_bool():
strided_slice(np.bool) strided_slice(np.bool)
x = Tensor(np.arange(0, 4*4*4).reshape(4, 4, 4).astype(np.float32))
y = x[-8:, :8]
expect = np.array([[[0., 1., 2., 3.],
[4., 5., 6., 7.],
[8., 9., 10., 11.],
[12., 13., 14., 15.]],
[[16., 17., 18., 19.],
[20., 21., 22., 23.],
[24., 25., 26., 27.],
[28., 29., 30., 31.]],
[[32., 33., 34., 35.],
[36., 37., 38., 39.],
[40., 41., 42., 43.],
[44., 45., 46., 47.]],
[[48., 49., 50., 51.],
[52., 53., 54., 55.],
[56., 57., 58., 59.],
[60., 61., 62., 63.]]])
assert np.allclose(y.asnumpy(), expect)