forked from mindspore-Ecosystem/mindspore
* fix bool index
* change slice setitem to mixed procedure * add testcase for slice assignment
This commit is contained in:
parent
0478b7d191
commit
bc0455eaff
|
@ -167,12 +167,13 @@ def _tensor_getitem(self, index):
|
|||
return tensor_index_by_tensor(self, index)
|
||||
if isinstance(index, tuple):
|
||||
return tensor_index_by_tuple(self, index)
|
||||
# bool type should be judged before int
|
||||
if isinstance(index, bool):
|
||||
return _tensor_index_by_bool(self, index)
|
||||
if isinstance(index, int):
|
||||
return _tensor_index_by_integer(self, index)
|
||||
if isinstance(index, slice):
|
||||
return tensor_index_by_slice(self, index)
|
||||
if isinstance(index, bool):
|
||||
return _tensor_index_by_bool(self, index)
|
||||
if index is None:
|
||||
return F.expand_dims(self, 0)
|
||||
if index is ...:
|
||||
|
@ -206,7 +207,8 @@ def tensor_index_by_slice(data, slice_index):
|
|||
"""Tensor getitem by a single slice"""
|
||||
shape = F.shape(data)
|
||||
if not shape:
|
||||
const_utils.raise_index_error("When tensor is indexed by a slice, the dimension of the tensor cannot be 0.")
|
||||
const_utils.raise_index_error("When tensor is indexed by a slice, the dimension of the tensor\
|
||||
cannot be 0.")
|
||||
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_slice(shape, slice_index)
|
||||
return F.strided_slice(data, begin_strides, end_strides, step_strides)
|
||||
|
||||
|
@ -215,7 +217,11 @@ def _tensor_index_by_integer(data, number):
|
|||
"""Tensor getitem by a single integer number"""
|
||||
shape = F.shape(data)
|
||||
if not shape:
|
||||
const_utils.raise_index_error("When tensor is indexed by an integer, the dimension of the tensor cannot be 0.")
|
||||
return const_utils.raise_index_error("When tensor is indexed by an integer,\
|
||||
the dimension of the tensor cannot be 0.")
|
||||
if number >= shape[0]:
|
||||
return const_utils.raise_index_error("index {} is out of bounds for axis 0 with size {}".format(
|
||||
number, shape[0]))
|
||||
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_integer(shape, number)
|
||||
shrink_axis_mask = 1
|
||||
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)
|
||||
|
@ -427,8 +433,6 @@ def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
|
|||
indexes_types = hyper_map(F.typeof, tuple_index)
|
||||
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
|
||||
|
||||
if index_elements_type == const_utils.NO_TENSOR:
|
||||
return tensor_setitem_by_slice_with_number(data, tuple_index, value)
|
||||
if index_elements_type == const_utils.ALL_TENSOR:
|
||||
indices = _generate_indices_from_tuple_of_tensor(data,
|
||||
tuple_index,
|
||||
|
@ -488,8 +492,6 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
|
|||
indexes_types = hyper_map(F.typeof, tuple_index)
|
||||
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
|
||||
|
||||
if index_elements_type == const_utils.NO_TENSOR:
|
||||
return tensor_setitem_by_slice_with_tensor(data, tuple_index, value)
|
||||
if index_elements_type == const_utils.ALL_TENSOR:
|
||||
indices = _generate_indices_from_tuple_of_tensor(data,
|
||||
tuple_index,
|
||||
|
|
|
@ -339,6 +339,8 @@ def check_tensors_dtype_same(data_dtype, value_dtype, op_name):
|
|||
@constexpr
|
||||
def generate_broadcast_shape(shapes, op_name):
|
||||
"""Generate broadcast shape for a tuple of shape."""
|
||||
if not shapes:
|
||||
return ()
|
||||
broadcast_shape = shapes[0]
|
||||
for i, shape in enumerate(shapes):
|
||||
logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.")
|
||||
|
@ -541,6 +543,11 @@ def generate_index_info_from_tuple_of_mixed_tensors(data_shape,
|
|||
slice_indexes[slice_count].step)
|
||||
# Use list to represent slicing result.
|
||||
indexes_info[pos] = list(range(data_shape[pos]))[slice_obj]
|
||||
if not indexes_info[pos]:
|
||||
raise IndexError("An empty slice is not supported, got {}:{}:{}".format(
|
||||
slice_indexes[slice_count].start,
|
||||
slice_indexes[slice_count].stop,
|
||||
slice_indexes[slice_count].step))
|
||||
slice_count += 1
|
||||
elif isinstance(ele_type, mstype.ellipsis_type):
|
||||
if ellipsis_num != 0:
|
||||
|
|
|
@ -646,7 +646,7 @@ class TensorAssignWithSlice2(Cell):
|
|||
class TensorAssignWithSlice(Cell):
|
||||
def __init__(self):
|
||||
super(TensorAssignWithSlice, self).__init__()
|
||||
self.c = 2
|
||||
self.c = 2.0
|
||||
|
||||
def construct(self, a, b, ck):
|
||||
a[1:3, ::] = b
|
||||
|
@ -661,7 +661,47 @@ class TensorAssignWithSlice(Cell):
|
|||
return z
|
||||
|
||||
|
||||
def test_tensor_assign():
|
||||
def test_tensor_assign_slice_value_1():
|
||||
net = TensorAssignWithSlice()
|
||||
a = np.arange(60).reshape(3, 4, 5)
|
||||
ck = np.arange(60).reshape(3, 4, 5)
|
||||
b = np.array([1]).astype(np.float32) # Tensor([1], dtype=mstype.float32)
|
||||
tb = Tensor(b, dtype=mstype.float32)
|
||||
ta = Tensor(a, dtype=mstype.float32)
|
||||
tck = Tensor(ck, dtype=mstype.float32)
|
||||
out = net(ta, tb, tck)
|
||||
a[1:3, ::] = b
|
||||
a[2:3:, 3:] = b
|
||||
a[::] = b
|
||||
a[::] = 2.0
|
||||
a[::, ::] = b
|
||||
a[::, ::] = 2.0
|
||||
a[2:3:, 0:, 4:1:-1] = b
|
||||
a[2:3:, 0:, 4:1:-1] = 2.0
|
||||
z = a + ck
|
||||
assert np.all(z == out.asnumpy())
|
||||
|
||||
|
||||
def test_tensor_assign_slice_value_2():
|
||||
net2 = TensorAssignWithSlice2()
|
||||
a = np.array([1, 2, 3, 4, 5, 6, 7, 8])
|
||||
ck = np.array([1, 2, 3, 4, 5, 6, 7, 8])
|
||||
b = np.array([1]).astype(np.float32) # Tensor([1], dtype=mstype.float32)
|
||||
tb = Tensor(b, dtype=mstype.float32)
|
||||
ta = Tensor(a, dtype=mstype.float32)
|
||||
tck = Tensor(ck, dtype=mstype.float32)
|
||||
out = net2(ta, tb, tck)
|
||||
a[1:5] = b
|
||||
a[3:4] = 5
|
||||
a[-1:1:-1] = b
|
||||
a[-1:3:-1] = 5
|
||||
a[::] = b
|
||||
a[::] = 9
|
||||
z = a + ck
|
||||
assert np.all(z == out.asnumpy())
|
||||
|
||||
|
||||
def test_tensor_assign_exception():
|
||||
net = TensorAssignWithSlice()
|
||||
net2 = TensorAssignWithSlice2()
|
||||
net_e1 = TensorAssignWithSliceError1()
|
||||
|
@ -677,8 +717,6 @@ def test_tensor_assign():
|
|||
Tc = Tensor([], dtype=mstype.float32)
|
||||
t = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
|
||||
tck = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
|
||||
net(Ta, b, Tck)
|
||||
net2(t, b, tck)
|
||||
# Error for A[Slice] = Number
|
||||
# 1. A[Slice] = Number, Slice error
|
||||
with pytest.raises(IndexError):
|
||||
|
@ -744,9 +782,6 @@ def test_tensor_assign():
|
|||
# 2. A[::, 1:, ...] = scalar/tensor
|
||||
net = TensorAssignWithTupleEllipsis()
|
||||
net(Ta, b)
|
||||
Tc = Tensor(1, mstype.float32)
|
||||
with pytest.raises(ValueError):
|
||||
net(Ta, Tc)
|
||||
with pytest.raises(ValueError):
|
||||
net(Ta, Tb)
|
||||
|
||||
|
@ -765,7 +800,7 @@ class TensorAssignWithTupleEllipsis(Cell):
|
|||
super(TensorAssignWithTupleEllipsis, self).__init__()
|
||||
|
||||
def construct(self, a, b):
|
||||
a[:2, ...] = 1
|
||||
a[:2, ...] = 1.0
|
||||
a[1:, ...] = b
|
||||
return a
|
||||
|
||||
|
@ -955,3 +990,16 @@ def Xtest_tensor_slice_reduce_out_of_bounds_positive():
|
|||
with pytest.raises(ValueError) as ex:
|
||||
net(input_tensor)
|
||||
assert "For 'StridedSlice' the `begin[0]` should be an int and must less than 6, but got `6`" in str(ex.value)
|
||||
|
||||
|
||||
def test_tensor_range():
|
||||
a = np.arange(4*5*6).reshape(4, 5, 6).astype(np.float32)
|
||||
ta = Tensor(a, mstype.float32)
|
||||
ms_out = []
|
||||
for item in ta:
|
||||
ms_out.append(item)
|
||||
np_out = []
|
||||
for item in a:
|
||||
np_out.append(item)
|
||||
for i, elem in enumerate(ms_out):
|
||||
assert np.all(elem.asnumpy() == np_out[i])
|
||||
|
|
|
@ -130,7 +130,7 @@ class TensorAssignWithSlice2(Cell):
|
|||
class TensorAssignWithSlice(Cell):
|
||||
def __init__(self):
|
||||
super(TensorAssignWithSlice, self).__init__()
|
||||
self.c = 2
|
||||
self.c = 2.0
|
||||
|
||||
def construct(self, a, b, ck):
|
||||
a[1:3, ::] = b
|
||||
|
@ -528,8 +528,7 @@ def test_tensor_assign():
|
|||
net = TensorAssignWithTupleEllipsis()
|
||||
net(Ta, b)
|
||||
Tc = Tensor(1, mstype.float32)
|
||||
with pytest.raises(ValueError):
|
||||
net(Ta, Tc)
|
||||
net(Ta, Tc)
|
||||
with pytest.raises(ValueError):
|
||||
net(Ta, Tb)
|
||||
|
||||
|
@ -548,7 +547,7 @@ class TensorAssignWithTupleEllipsis(Cell):
|
|||
super(TensorAssignWithTupleEllipsis, self).__init__()
|
||||
|
||||
def construct(self, a, b):
|
||||
a[:2, ...] = 1
|
||||
a[:2, ...] = 1.0
|
||||
a[1:, ...] = b
|
||||
return a
|
||||
|
||||
|
@ -579,10 +578,10 @@ class TensorAssignWithTupleInteger(Cell):
|
|||
super(TensorAssignWithTupleInteger, self).__init__()
|
||||
|
||||
def construct(self, a, b, ck):
|
||||
a[(1)] = 1
|
||||
a[(1)] = 1.0
|
||||
a[(1)] = b
|
||||
a[(1, 1)] = b
|
||||
a[(1, 1)] = 1
|
||||
a[(1, 1)] = 1.0
|
||||
z = a + ck
|
||||
return z
|
||||
|
||||
|
|
Loading…
Reference in New Issue