!2747 refactor StridedSlice op

Merge pull request !2747 from zhangbuxue/refactor_the_StridedSlice_op
mindspore-ci-bot 2020-07-01 12:04:11 +08:00 committed by Gitee
commit 30fc1bd0c3
3 changed files with 236 additions and 83 deletions


@@ -37,6 +37,7 @@ from ..._c_expression import signature_kind as sig_kind
from ..._c_expression import signature_dtype as sig_dtype
from ..._c_expression import typing
def _check_infer_attr_reduce(axis, keep_dims, prim_name):
validator.check_value_type('keep_dims', keep_dims, [bool], prim_name)
validator.check_value_type('axis', axis, [int, tuple], prim_name)
@@ -193,7 +194,7 @@ class Cast(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output'])
def check_elim(self, x, dtype):
if isinstance(x, (Tensor, numbers.Number, Parameter)):
if isinstance(x, Tensor) and x.dtype == dtype:
return (True, x)
if isinstance(x, numbers.Number):
@@ -987,10 +988,10 @@ class InvertPermutation(PrimitiveWithInfer):
z.sort()
for i in range(1, len(z)):
- if z[i-1] == z[i]:
+ if z[i - 1] == z[i]:
raise ValueError(f"For {self.name}, {z[i]} is duplicated in the input.")
validator.check(f'value min', min(x_value), '', 0, Rel.EQ, self.name)
- validator.check(f'value max', max(x_value), '', len(x_value)-1, Rel.EQ, self.name)
+ validator.check(f'value max', max(x_value), '', len(x_value) - 1, Rel.EQ, self.name)
y = [None] * len(x_value)
for i, value in enumerate(x_value):
@@ -1693,6 +1694,57 @@ class Select(PrimitiveWithInfer):
return None
def _compute_slicing_length(begin, end, stride, x_shape, i):
"""Compute the length of the slicing."""
if i >= len(x_shape):
raise ValueError(f"For 'StridedSlice', When their is no new axis, the index length must be less or "
f"equal than the dim of x.")
x_dim = x_shape[i]
if stride > 0:
# When slicing forward, convert begin and end to positive numbers.
if begin >= x_dim or end < -x_dim:
# When slicing forward, if begin >= x_dim or end < -x_dim, the length of the slicing is 0.
slicing_length = 0
else:
if -x_dim <= begin < 0:
begin += x_dim
if begin < -x_dim:
# When slicing forward, if begin < -x_dim, set begin = 0, which means start from the 0th element.
begin = 0
if -x_dim <= end < 0:
end += x_dim
if end > x_dim:
# When slicing forward, if end > x_dim, set end = x_dim, which means slice to the last element.
end = x_dim
if begin >= end:
# When slicing forward, if begin >= end, the length of the slicing is 0.
slicing_length = 0
else:
slicing_length = 1 + (end - 1 - begin) // stride
else:
# When slicing backward, convert begin and end to negative numbers.
if begin < -x_dim or end >= x_dim:
# When slicing backward, if begin < -x_dim or end >= x_dim, the length of the slicing is 0.
slicing_length = 0
else:
if 0 <= begin < x_dim:
begin += -x_dim
if begin >= x_dim:
# When slicing backward, if begin >= x_dim, set begin = -1, which means start from the last element.
begin = -1
if 0 < end < x_dim:
end += -x_dim
if end < -x_dim - 1:
# When slicing backward, if end < -x_dim - 1, set end = -x_dim - 1, which means
# slicing to the 0th element.
end = -x_dim - 1
if begin <= end:
# When slicing backward, if begin <= end, the length of the slicing is 0.
slicing_length = 0
else:
slicing_length = 1 + (end + 1 - begin) // stride
return slicing_length
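A few sanity checks of the helper above (an illustrative sketch, not part of the commit; it assumes _compute_slicing_length is in scope, and the expected values follow directly from the forward and backward branches):

# Illustrative checks for _compute_slicing_length; x_shape=[6], so x_dim is 6 at index 0.
assert _compute_slicing_length(1, 5, 2, [6], 0) == 2    # like x[1:5:2]: elements 1 and 3
assert _compute_slicing_length(-1, -7, -1, [6], 0) == 6  # like x[::-1]: all 6 elements, reversed
assert _compute_slicing_length(4, 2, 1, [6], 0) == 0    # forward slice with begin >= end is empty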
class StridedSlice(PrimitiveWithInfer):
r"""
@@ -1756,13 +1808,15 @@ class StridedSlice(PrimitiveWithInfer):
ellipsis_mask=0,
new_axis_mask=0,
shrink_axis_mask=0):
"""init StrideSlice"""
"""Init StrideSlice"""
self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output'])
validator.check_value_type('begin_mask', begin_mask, [int], self.name)
validator.check_value_type('end_mask', end_mask, [int], self.name)
validator.check_value_type('ellipsis_mask', ellipsis_mask, [int], self.name)
validator.check_value_type('new_axis_mask', new_axis_mask, [int], self.name)
validator.check_value_type('shrink_axis_mask', shrink_axis_mask, [int], self.name)
validator.check_integer('begin_mask', begin_mask, 0, Rel.GE, self.name)
validator.check_integer('end_mask', end_mask, 0, Rel.GE, self.name)
validator.check_integer('ellipsis_mask', ellipsis_mask, 0, Rel.GE, self.name)
if len(tuple(filter(lambda x: x == '1', bin(ellipsis_mask)[-1:1:-1]))) > 1:
raise ValueError(f"For '{self.name}', only support one ellipsis in the index, but got {end_mask}.")
validator.check_integer('new_axis_mask', new_axis_mask, 0, Rel.GE, self.name)
validator.check_integer('shrink_axis_mask', shrink_axis_mask, 0, Rel.GE, self.name)
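One note on the mask decoding used above and throughout the refactor: bin() returns a string prefixed with '0b', and the slice [-1:1:-1] both strips that prefix and reverses the string, so position k of the result is bit k of the mask. A small illustration (assumed, not part of the commit):

mask = 4                    # binary 0b100: bit 2 set, e.g. an ellipsis at index 2
bits = bin(mask)[-1:1:-1]   # '001'
assert [i for i, b in enumerate(bits) if b == '1'] == [2]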
def __infer__(self, x, begin, end, strides):
begin_v, end_v, strides_v = begin['value'], end['value'], strides['value']
@@ -1770,58 +1824,103 @@ class StridedSlice(PrimitiveWithInfer):
validator.check_value_type("end", end_v, [tuple], self.name)
validator.check_value_type("strides", strides_v, [tuple], self.name)
x_shape = x['shape']
x_shp_len = len(x_shape)
if len(begin_v) != x_shp_len or len(end_v) != x_shp_len or len(strides_v) != x_shp_len:
raise ValueError(f"For \'{self.name}\' the length of begin index{begin_v}, end index{end_v} and "
f"strides{strides_v} must be equal to the dims({x_shp_len}) of input.")
if tuple(filter(lambda x: not isinstance(x, int), begin_v + end_v + strides_v)):
raise ValueError(f"For {self.name}, both the begins, ends, and strides must be a tuple of int, "
f"but got begins: {begin_v}, ends: {end_v}, strides: {strides_v}.")
ret_shape = []
append_dimensions = []
shrink_pos = bin(self.shrink_axis_mask)[::-1]
new_pos = bin(self.new_axis_mask)[::-1]
for i in range(x_shp_len):
# After the integer is converted to binary, it is a str and the first two chars are the prefix '0b'
if i < (len(new_pos) - 2) and new_pos[i] == '1':
ret_shape.append(1)
append_dimensions.append(x_shape[x_shp_len - 1 - len(append_dimensions)])
continue
if i < (len(shrink_pos) - 2) and shrink_pos[i] == '1':
validator.check_integer(f'begin[{i}]', begin_v[i], -x_shape[i], Rel.GE, self.name)
validator.check_integer(f'begin[{i}]', begin_v[i], x_shape[i], Rel.LT, self.name)
continue
if tuple(filter(lambda x: x == 0, strides_v)):
raise ValueError(f"For '{self.name}', the strides cannot contain 0, but got strides: {strides_v}.")
begin_idx = begin_v[i]
end_idx = end_v[i]
strides_idx = strides_v[i]
if self.begin_mask:
begin_idx = 0
if self.end_mask:
end_idx = x_shape[i]
validator.check_integer(f'begin[{i}]', begin_idx, x_shape[i], Rel.LE, self.name)
validator.check_integer(f'end[{i}]', end_idx, x_shape[i], Rel.LE, self.name)
validator.check_integer(f'strides[{i}]', strides_idx, 0, Rel.NE, self.name)
if strides_idx > 0:
# If sliced forward, end_idx >= begin_idx
validator.check(f'begin[{i}]', begin_idx, f'end[{i}]', end_idx, Rel.LE)
if begin_idx < 0 < end_idx:
# Turn negative begin_idx into positive values
begin_idx = x_shape[i] + begin_idx
num_elems = (end_idx - begin_idx + strides_idx - 1) // strides_idx
else:
# If sliced backwards, end_idx <= begin_idx
validator.check(f'begin[{i}]', begin_idx, f'end[{i}]', end_idx, Rel.GE)
if end_idx < 0 < begin_idx:
# Turn negative end_idx into positive values
end_idx = x_shape[i] + end_idx
num_elems = (end_idx - begin_idx + strides_idx + 1) // strides_idx
if len(end_v) != len(begin_v) or len(strides_v) != len(begin_v):
raise ValueError(f"For '{self.name}' the length of begin index: {begin_v}, end index: {end_v} and "
f"strides: {strides_v} must be equal.")
ret_shape.append(num_elems)
if append_dimensions:
ret_shape += append_dimensions[::-1]
ret_shape = self._compute_slicing_shape(x['shape'], begin_v, end_v, strides_v)
value = None if all(ret_shape) else Tensor(np.array([]).reshape(ret_shape), x['dtype'].element_type())
return {'shape': ret_shape,
'dtype': x['dtype'],
- 'value': None}
+ 'value': value}
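A hedged reading of the change just above: when any inferred dimension is 0, the output holds no data, so __infer__ can constant-fold the result to an empty Tensor instead of returning None. With illustrative values:

ret_shape = [0, 1, 7]
assert not all(ret_shape)  # a 0 in the shape makes all() falsy, so a value is materialized
# np.array([]).reshape([0, 1, 7]) is a valid zero-element array of that shape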
def _compute_slicing_shape(self, x_shape, begin_v, end_v, strides_v):
"""Compute the shape of the slicing."""
x_rank = len(x_shape)
slice_len = len(begin_v)
# After the integer is converted to binary, it is a str and the first two chars are the prefix '0b'.
begin_pos = bin(self.begin_mask)[-1:1:-1]
end_pos = bin(self.end_mask)[-1:1:-1]
ellipsis_pos = bin(self.ellipsis_mask)[-1:1:-1]
new_axis_pos = bin(self.new_axis_mask)[-1:1:-1]
shrink_axis_pos = bin(self.shrink_axis_mask)[-1:1:-1]
ret_shape = []
i, j = 0, 0
has_ellipsis = False
while i < x_rank or j < slice_len:
if j < slice_len:
begin, end, stride = begin_v[j], end_v[j], strides_v[j]
if j < len(ellipsis_pos) and ellipsis_pos[j] == '1':
# When there is an ellipsis, the part after the ellipsis is processed separately.
has_ellipsis = True
break
if j < len(begin_pos) and begin_pos[j] == '1':
begin = -1 if strides_v[j] < 0 else 0
if j < len(end_pos) and end_pos[j] == '1':
end = -(x_shape[i] + 1) if strides_v[j] < 0 else x_shape[i]
if j < len(new_axis_pos) and new_axis_pos[j] == '1':
ret_shape.append(1)
j += 1
continue
if j < len(shrink_axis_pos) and shrink_axis_pos[j] == '1':
if (not -x_shape[i] <= begin < x_shape[i]) or stride < 0:
raise ValueError(f"For {self.name}, when shrink axis, the stride cannot be negative number, "
f"and begin should be in [-{x_shape[i]}, {x_shape[i]}), "
f"but got stride: {stride}, begin: {begin}.")
j += 1
i += 1
continue
else:
begin, end, stride = 0, x_shape[i], 1
slicing_length = _compute_slicing_length(begin, end, stride, x_shape, i)
ret_shape.append(slicing_length)
i += 1
j += 1
if has_ellipsis:
# When there is an ellipsis, handle the part after the ellipsis.
ellipsis_occupied_dims = x_rank - i - (slice_len - (j + 1)) + \
len(tuple(filter(lambda x: x == '1', new_axis_pos[j + 1:slice_len])))
ret_shape.extend(x_shape[i:i + ellipsis_occupied_dims])
j += 1
i += ellipsis_occupied_dims
while i < x_rank or j < slice_len:
begin, end, stride = begin_v[j], end_v[j], strides_v[j]
if j < len(begin_pos) and begin_pos[j] == '1':
begin = -1 if strides_v[j] < 0 else 0
if j < len(end_pos) and end_pos[j] == '1':
end = -(x_shape[i] + 1) if strides_v[j] < 0 else x_shape[i]
if j < len(new_axis_pos) and new_axis_pos[j] == '1':
ret_shape.append(1)
j += 1
continue
if j < len(shrink_axis_pos) and shrink_axis_pos[j] == '1':
if (not -x_shape[i] <= begin < x_shape[i]) or stride < 0:
raise ValueError(f"For {self.name}, when shrink axis, the stride cannot be negative number, "
f"and begin should be in [-{x_shape[i]}, {x_shape[i]}), "
f"but got stride: {stride}, begin: {begin}.")
j += 1
i += 1
continue
slicing_length = _compute_slicing_length(begin, end, stride, x_shape, i)
ret_shape.append(slicing_length)
i += 1
j += 1
return ret_shape
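For intuition on the ellipsis_occupied_dims formula above: it counts the input dimensions the ellipsis must absorb, i.e. the dims not already consumed before the ellipsis (i) and not claimed by the remaining slice entries (slice_len - (j + 1)), with new-axis entries after the ellipsis added back since they consume no input dim. For example, with x_rank = 5, the ellipsis met at i = 1, j = 2, and one new-axis bit after it, the ellipsis expands to 5 - 1 - (5 - 3) + 1 = 3 dimensions, so three input dims are copied through unchanged.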
class Diag(PrimitiveWithInfer):
@@ -2102,6 +2201,7 @@ class TensorScatterUpdate(PrimitiveWithInfer):
>>> op = P.TensorScatterUpdate()
>>> output = op(input_x, indices, update)
"""
@prim_attr_register
def __init__(self):
"""Init TensorScatterUpdate"""
@@ -2153,6 +2253,7 @@ class ScatterUpdate(PrimitiveWithInfer):
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
)
@prim_attr_register
def __init__(self, use_locking=True):
"""Init ScatterUpdate"""
@@ -2201,6 +2302,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
)
@prim_attr_register
def __init__(self, use_locking=True):
"""Init ScatterNdUpdate"""
@@ -2220,6 +2322,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype
def _check_scatter_shape(x_shape, indices_shape, updates_shape, prim_name):
if updates_shape and updates_shape != indices_shape + x_shape[1:]:
raise ValueError(f"For '{prim_name}', the shape of updates should be [] or "
@@ -2912,6 +3015,7 @@ class InplaceUpdate(PrimitiveWithInfer):
[ 4. 5.]
[ 6. 7.]]]
"""
@prim_attr_register
def __init__(self, indices):
"""Init InplaceUpdate"""


@@ -35,25 +35,6 @@ from ....mindspore_test_framework.pipeline.gradient.compile_gradient \
import pipeline_for_compile_grad_ge_graph_for_case_by_case_config
def test_tensor_scatter_update():
class TensorScatterUpdateNet(nn.Cell):
"""TensorScatterUpdate net definition"""
def __init__(self):
super(TensorScatterUpdateNet, self).__init__()
self.tensor_scatter_update = P.TensorScatterUpdate()
def construct(self, x, i, u):
out = self.tensor_scatter_update(x, i, u)
return out
net = TensorScatterUpdateNet()
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32)
indices = Tensor(np.array([[0, 0], [1, 1]], np.int32))
updates = Tensor(np.ones([2, 5], np.float32))
net(x, indices, updates)
class InputBackward(nn.Cell):
def __init__(self, network):
super(InputBackward, self).__init__()
@@ -446,6 +427,7 @@ class SparseApplyAdagradNet(nn.Cell):
out = self.sparse_apply_adagrad(self.var, self.accum, grad, indices)
return out
class ApplyRMSNet(nn.Cell):
def __init__(self):
super(ApplyRMSNet, self).__init__()
@@ -496,6 +478,60 @@ class NormalNet(nn.Cell):
return out
class StridedSliceNet(nn.Cell):
def __init__(self):
super(StridedSliceNet, self).__init__()
self.begins = (1, 2, 3, 2, 1)
self.ends = (5, 6, 7, 8, 9)
self.strides = (1, 2, 3, 2, 1)
self.strided_slice_0 = P.StridedSlice(begin_mask=3, end_mask=5, ellipsis_mask=4,
shrink_axis_mask=2, new_axis_mask=8)
self.strided_slice_1 = P.StridedSlice(begin_mask=5, end_mask=2, ellipsis_mask=2,
shrink_axis_mask=6, new_axis_mask=10)
self.strided_slice_2 = P.StridedSlice(begin_mask=3, end_mask=3, ellipsis_mask=4,
shrink_axis_mask=5, new_axis_mask=13)
self.strided_slice_3 = P.StridedSlice(begin_mask=0, end_mask=0, ellipsis_mask=4,
shrink_axis_mask=12, new_axis_mask=15)
self.const_0 = Tensor(np.ones([6, 8, 9, 1, 8], np.float32))
self.const_1 = Tensor(np.ones([5, 7, 8, 1, 8], np.float32))
self.const_2 = Tensor(np.ones([1, 3, 7, 8, 9, 1, 8], np.float32))
self.const_3 = Tensor(np.ones([1, 1, 6, 7, 8, 9, 1, 8], np.float32))
def construct(self, x):
out_0 = self.strided_slice_0(x, self.begins, self.ends, self.strides) + self.const_0
out_1 = self.strided_slice_1(x, self.begins, self.ends, self.strides) + self.const_1
out_2 = self.strided_slice_2(x, self.begins, self.ends, self.strides) + self.const_2
out_3 = self.strided_slice_3(x, self.begins, self.ends, self.strides) + self.const_3
return out_0, out_1, out_2, out_3
def test_strided_slice_const():
class StridedSliceConstNet(nn.Cell):
"""StridedSliceConstNet net definition"""
def __init__(self):
super(StridedSliceConstNet, self).__init__()
self.begins = (0, 2, -5, 2, 1)
self.ends = (0, 6, 9, 8, 9)
self.strides = (1, 2, 1, 2, 1)
self.strided_slice = P.StridedSlice(begin_mask=2,
end_mask=6,
ellipsis_mask=4,
shrink_axis_mask=6,
new_axis_mask=18)
def construct(self, x):
out = self.strided_slice(x, self.begins, self.ends, self.strides)
return out
net = StridedSliceConstNet()
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
x = Tensor(np.ones([6, 7, 8, 9, 10]), mstype.float32)
ret = net(x)
assert ret.shape == (0, 1, 7, 8, 9, 3, 1)
assert (ret.asnumpy() == np.array([], np.float32).reshape([0, 1, 7, 8, 9, 3, 1])).all()
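As a hedged cross-check of that expected shape (numpy is used only for comparison; the masks decode to a plain slice at index 0, new axes at indices 1 and 4, and an ellipsis at index 2, with the shrink bits overridden by the new-axis and ellipsis bits at those positions):

import numpy as np
x = np.ones([6, 7, 8, 9, 10], np.float32)
assert x[0:0, None, ..., 2:8:2, None].shape == (0, 1, 7, 8, 9, 3, 1)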
test_case_math_ops = [
('BitwiseAnd', {
'block': P.BitwiseAnd(),
@@ -1366,6 +1402,10 @@ test_case_nn_ops = [
'desc_inputs': [Tensor(np.array([[128, 32, 32, 64], [128, 32, 32, 64]]).astype(np.float16))],
'desc_bprop': [Tensor(np.array([[128, 32, 32, 64], [128, 32, 32, 64]]).astype(np.float16))],
'skip': ['backward']}),
('StridedSliceNet', {
'block': StridedSliceNet(),
'desc_inputs': [[6, 7, 8, 9, 10]],
'skip': ['backward']}),
('OneHot', {
'block': P.OneHot(),
'desc_const': [3, Tensor(1.0, mstype.float32), Tensor(0.0, mstype.float32)],
@@ -1763,7 +1803,7 @@ test_case_other_ops = [
'desc_bprop': [([3, 3], {'dtype': np.int32})]}),
('TensorScatterUpdate', {
'block': P.TensorScatterUpdate(),
'desc_inputs': (Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32),
Tensor(np.array([[0, 1], [1, 2]], np.int32)),
Tensor(np.ones([2, 5], np.float32) * 99)),
'desc_bprop': [([3, 4, 5], {'dtype': np.float32})]}),
@@ -1930,11 +1970,10 @@ test_case_other_ops = [
]
test_case_quant_ops = [
('AscendQuant_1', {
'block': inner.AscendQuant(0.5, 0.0, False, "Round"),
- 'desc_inputs': [Tensor(np.random.rand(1,2,4,4), mstype.float32)],
+ 'desc_inputs': [Tensor(np.random.rand(1, 2, 4, 4), mstype.float32)],
'skip': ['backward']}),
('AscendQuant_2', {
'block': inner.AscendQuant(80.0, 10.0, True, "Round"),
@@ -2027,6 +2066,18 @@ raise_set = [
'block': (nn.SSIM(), {'exception': ValueError}),
'desc_inputs': [Tensor(np.ones((1, 3, 8, 8)), mstype.float32),
Tensor(np.ones((1, 3, 8, 8)), mstype.float32)]}),
('StridedSlice_0', {
'block': (P.StridedSlice(), {'exception': ValueError}),
'desc_const': [(1, 2.2, 3), (3, 4, 5), (1, 1, 1)],
'desc_inputs': [[4, 5, 6, 7]]}),
('StridedSlice_1', {
'block': (P.StridedSlice(), {'exception': ValueError}),
'desc_const': [(1, 2, 3), (3, 4, 5), (1, 1)],
'desc_inputs': [[4, 5, 6, 7]]}),
('StridedSlice_2', {
'block': (P.StridedSlice(), {'exception': ValueError}),
'desc_const': [(1, 2, 3), (3, 4, 5), (1, 1, 0)],
'desc_inputs': [[4, 5, 6, 7]]}),
]


@@ -25,6 +25,7 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config, \
pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception
class NetWorkSlicePositive(Cell):
def __init__(self):
super(NetWorkSlicePositive, self).__init__()
@@ -1159,10 +1160,8 @@ def test_tensor_slice_reduce_out_of_bounds_neg():
input_tensor = Tensor(np.ones([6, 8, 10], np.int32))
net = NetWork()
- with pytest.raises(ValueError) as ex:
+ with pytest.raises(ValueError):
net(input_tensor)
- assert "For 'StridedSlice' the `begin[0]` should be an int and must greater or equal to -6, but got `-7`" in str(
- ex.value)
def test_tensor_slice_reduce_out_of_bounds_positive():
@@ -1177,6 +1176,5 @@ def test_tensor_slice_reduce_out_of_bounds_positive():
input_tensor = Tensor(np.ones([6, 8, 10], np.int32))
net = NetWork()
- with pytest.raises(ValueError) as ex:
+ with pytest.raises(ValueError):
net(input_tensor)
- assert "For 'StridedSlice' the `begin[0]` should be an int and must less than 6, but got `6`" in str(ex.value)