forked from mindspore-Ecosystem/mindspore
gpu stridedslice
This commit is contained in:
parent
9b9d7be73d
commit
837aecf9af
|
@ -59,6 +59,7 @@ class StridedSliceGpuKernel : public GpuKernel {
|
|||
ParseMasks(kernel_node);
|
||||
FillOutputDim();
|
||||
null_output_ = IsNullOutput();
|
||||
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
@ -86,14 +87,15 @@ class StridedSliceGpuKernel : public GpuKernel {
|
|||
|
||||
for (size_t i = 0; i < MAX_DIMS; i++) {
|
||||
if (i < begin_.size()) {
|
||||
begin_[i] =
|
||||
std::min(begin_[i] < 0 ? SizeToInt(begin_[i] + input_shape_[i]) : begin_[i], SizeToInt(input_shape_[i] - 1));
|
||||
int dim = SizeToInt(input_shape_[i]);
|
||||
begin_[i] = std::min(begin_[i] < 0 ? std::max(begin_[i] + dim, 0) : begin_[i], dim - 1);
|
||||
} else {
|
||||
begin_.push_back(0);
|
||||
}
|
||||
|
||||
if (i < end_.size()) {
|
||||
end_[i] = std::max(end_[i] < 0 ? end_[i] + SizeToInt(input_shape_[i]) : end_[i], -1);
|
||||
int dim = SizeToInt(input_shape_[i]);
|
||||
end_[i] = std::max(end_[i] < 0 ? end_[i] + dim : std::min(end_[i], dim), -1);
|
||||
} else {
|
||||
end_.push_back(i < input_shape_.size() ? input_shape_[i] : 1);
|
||||
}
|
||||
|
|
|
@ -87,14 +87,15 @@ class StridedSliceGradGpuKernel : public GpuKernel {
|
|||
|
||||
for (size_t i = 0; i < MAX_DIMS; i++) {
|
||||
if (i < begin_.size()) {
|
||||
begin_[i] =
|
||||
std::min(begin_[i] < 0 ? SizeToInt(begin_[i] + input_shape_[i]) : begin_[i], SizeToInt(input_shape_[i] - 1));
|
||||
int dim = SizeToInt(input_shape_[i]);
|
||||
begin_[i] = std::min(begin_[i] < 0 ? std::max(begin_[i] + dim, 0) : begin_[i], dim - 1);
|
||||
} else {
|
||||
begin_.push_back(0);
|
||||
}
|
||||
|
||||
if (i < end_.size()) {
|
||||
end_[i] = std::max(end_[i] < 0 ? end_[i] + SizeToInt(input_shape_[i]) : end_[i], -1);
|
||||
int dim = SizeToInt(input_shape_[i]);
|
||||
end_[i] = std::max(end_[i] < 0 ? end_[i] + dim : std::min(end_[i], dim), -1);
|
||||
} else {
|
||||
end_.push_back(i < input_shape_.size() ? input_shape_[i] : 1);
|
||||
}
|
||||
|
|
|
@ -150,73 +150,6 @@ def strided_slice_grad(nptype):
|
|||
[0., 0., 0., 0., 0.]]]]).astype(nptype)
|
||||
assert np.allclose(dx[0].asnumpy(), expect)
|
||||
|
||||
# ME infer fault
|
||||
# y = GradData()(x, (1, 0, -1, -2), (2, 2, 0, -5), (1, 1, -1, -2))
|
||||
# expect = np.array([[[[0., 0., 0., 0., 0.],
|
||||
# [0., 0., 0., 0., 0.],
|
||||
# [0., 0., 0., 0., 0.],
|
||||
# [0., 0., 0., 0., 0.]],
|
||||
|
||||
# [[0., 0., 0., 0., 0.],
|
||||
# [0., 0., 0., 0., 0.],
|
||||
# [0., 0., 0., 0., 0.],
|
||||
# [0., 0., 0., 0., 0.]],
|
||||
|
||||
# [[0., 0., 0., 0., 0.],
|
||||
# [0., 0., 0., 0., 0.],
|
||||
# [0., 0., 0., 0., 0.],
|
||||
# [0., 0., 0., 0., 0.]]],
|
||||
|
||||
|
||||
# [[[0., 0., 0., 0., 0.],
|
||||
# [0., 1., 0., 1., 0.],
|
||||
# [0., 1., 0., 1., 0.],
|
||||
# [0., 1., 0., 1., 0.]],
|
||||
|
||||
# [[0., 0., 0., 0., 0.],
|
||||
# [0., 1., 0., 1., 0.],
|
||||
# [0., 1., 0., 1., 0.],
|
||||
# [0., 1., 0., 1., 0.]],begin_mask=0b1000, end_mask=0b0010, ellipsis_mask=0b0100
|
||||
|
||||
# [[0., 0., 0., 0., 0.],
|
||||
# [0., 0., 0., 0., 0.],
|
||||
# [0., 0., 0., 0., 0.],
|
||||
# [0., 0., 0., 0., 0.]]]])
|
||||
# assert np.allclose(y.asnumpy(), expect)
|
||||
|
||||
# y = Grad(begin_mask=0b1000, end_mask=0b0010)(x, (1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1))
|
||||
# expect = np.array([[[[0., 0., 0., 0., 0.],
|
||||
# [0., 0., 0., 0., 0.],
|
||||
# [0., 0., 0., 0., 0.],
|
||||
# [0., 0., 0., 0., 0.]],
|
||||
|
||||
# [[0., 0., 0., 0., 0.],
|
||||
# [0., 0., 0., 0., 0.],
|
||||
# [0., 0., 0., 0., 0.],
|
||||
# [0., 0., 0., 0., 0.]],
|
||||
|
||||
# [[0., 0., 0., 0., 0.],
|
||||
# [0., 0., 0., 0., 0.],
|
||||
# [0., 0., 0., 0., 0.],
|
||||
# [0., 0., 0., 0., 0.]]],
|
||||
|
||||
|
||||
# [[[0., 0., 1., 1., 0.],
|
||||
# [0., 0., 1., 1., 0.],
|
||||
# [0., 0., 0., 0., 0.],
|
||||
# [0., 0., 0., 0., 0.]],
|
||||
|
||||
# [[0., 0., 1., 1., 0.],
|
||||
# [0., 0., 1., 1., 0.],
|
||||
# [0., 0., 0., 0., 0.],
|
||||
# [0., 0., 0., 0., 0.]],
|
||||
|
||||
# [[0., 0., 0., 0., 0.],
|
||||
# [0., 0., 0., 0., 0.],
|
||||
# [0., 0., 0., 0., 0.],
|
||||
# [0., 0., 0., 0., 0.]]]])
|
||||
# assert np.allclose(y.asnumpy(), expect)
|
||||
|
||||
|
||||
net = StridedSliceNet((1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1),
|
||||
begin_mask=0b1000, end_mask=0b0010, ellipsis_mask=0b0100)
|
||||
|
|
|
@ -45,23 +45,23 @@ def strided_slice(nptype):
|
|||
[89, 88, 87]]]]).astype(nptype)
|
||||
assert np.allclose(y.asnumpy(), expect)
|
||||
|
||||
# ME infer fault
|
||||
# y = P.StridedSlice()(x, (1, 0, -1, -2), (2, 2, 0, -5), (1, 1, -1, -2))
|
||||
# expect = np.array([[[[78, 76],
|
||||
# [73, 71],
|
||||
# [68, 66]],
|
||||
# [[98, 96],
|
||||
# [93, 91],
|
||||
# [88, 86]]]])
|
||||
# assert np.allclose(y.asnumpy(), expect)
|
||||
y = P.StridedSlice()(x, (1, 0, -1, -2), (2, 2, 0, -5), (1, 1, -1, -2))
|
||||
expect = np.array([[[[78, 76],
|
||||
[73, 71],
|
||||
[68, 66]],
|
||||
[[98, 96],
|
||||
[93, 91],
|
||||
[88, 86]]]]).astype(nptype)
|
||||
assert np.allclose(y.asnumpy(), expect)
|
||||
|
||||
# ME Infer fault
|
||||
# y = P.StridedSlice(begin_mask=0b1000, end_mask=0b0010)(x, (1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1))
|
||||
# expect = np.array([[[[ 62, 63],
|
||||
# [ 67, 68]],
|
||||
# [[ 82, 83],
|
||||
# [ 87, 88]],
|
||||
# expect = np.array([[[[62, 63],
|
||||
# [67, 68]],
|
||||
# [[82, 83],
|
||||
# [87, 88]],
|
||||
# [[102, 103],
|
||||
# [107, 108]]]])
|
||||
# [107, 108]]]]).astype(nptype)
|
||||
# assert np.allclose(y.asnumpy(), expect)
|
||||
|
||||
op = P.StridedSlice(begin_mask=0b1000, end_mask=0b0010, ellipsis_mask=0b0100)
|
||||
|
@ -125,3 +125,25 @@ def test_strided_slice_uint8():
|
|||
@pytest.mark.env_onecard
|
||||
def test_strided_slice_bool():
|
||||
strided_slice(np.bool)
|
||||
x = Tensor(np.arange(0, 4*4*4).reshape(4, 4, 4).astype(np.float32))
|
||||
y = x[-8:, :8]
|
||||
expect = np.array([[[0., 1., 2., 3.],
|
||||
[4., 5., 6., 7.],
|
||||
[8., 9., 10., 11.],
|
||||
[12., 13., 14., 15.]],
|
||||
|
||||
[[16., 17., 18., 19.],
|
||||
[20., 21., 22., 23.],
|
||||
[24., 25., 26., 27.],
|
||||
[28., 29., 30., 31.]],
|
||||
|
||||
[[32., 33., 34., 35.],
|
||||
[36., 37., 38., 39.],
|
||||
[40., 41., 42., 43.],
|
||||
[44., 45., 46., 47.]],
|
||||
|
||||
[[48., 49., 50., 51.],
|
||||
[52., 53., 54., 55.],
|
||||
[56., 57., 58., 59.],
|
||||
[60., 61., 62., 63.]]])
|
||||
assert np.allclose(y.asnumpy(), expect)
|
||||
|
|
Loading…
Reference in New Issue