fix & restructure setitem

This commit is contained in:
huangmengxi 2021-04-23 16:28:30 +08:00
parent 1c44e367e0
commit d6be043cb9
9 changed files with 349 additions and 382 deletions

View File

@ -58,9 +58,6 @@ def _tensor_setitem(self, index, value):
if isinstance(index, Tensor):
return tensor_setitem_by_tensor(self, index, value)
if isinstance(index, tuple):
if tuple_indices_have_false(index):
return self
index = format_tuple_indices(index)
return tensor_setitem_by_tuple(self, index, value)
if isinstance(index, bool):
return tensor_setitem_by_bool(self, index, value)
@ -68,7 +65,7 @@ def _tensor_setitem(self, index, value):
return tensor_setitem_by_number(self, index, value)
if isinstance(index, slice):
return tensor_setitem_by_slice(self, index, value)
if index is ...:
if index in (None, ...):
return tensor_setitem_by_ellipsis(self, index, value)
raise IndexError("Tensor setitem index only support integers, slices(`:`), ellipsis(`...`), bool, tensor, \
@ -142,7 +139,7 @@ tensor_operator_registry.register('__floordiv__', _tensor_floordiv)
def _broadcast(broadcast_shape, x):
"""Broadcast tensor to the required shape."""
if F.shape(x) == broadcast_shape:
if not const_utils.check_two_shapes_need_broadcast(broadcast_shape, F.shape(x)):
return x
multiples = const_utils.compute_multiples(F.shape(x), broadcast_shape)
if multiples:
@ -242,7 +239,7 @@ def _tensor_index_by_integer(data, int_index):
const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
data_shape = F.shape(data)
transformed_number = const_utils.check_and_transform_int_index(int_index, data_shape[0], const_utils.TENSOR_GETITEM)
transformed_number = const_utils.check_range(int_index, data_shape[0])
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_integer(data_shape, transformed_number)
shrink_axis_mask = 1
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)
@ -264,10 +261,9 @@ def tensor_index_by_list(data, list_index):
data_shape = F.shape(data)
indexes_types = hyper_map(F.typeof, list_index)
if const_utils.judge_indexes_types(indexes_types, mstype.int_type + (mstype.bool_,)):
sub_tuple_index = const_utils.transform_sequence_index(list_index, data_shape[0], const_utils.TENSOR_GETITEM)
if not sub_tuple_index:
tensor_index = const_utils.sequence_to_index(list_index, data_shape[0])
if tensor_index is False:
const_utils.raise_index_error("Getitem does not support empty list, this will reference shape '0'.")
tensor_index = const_utils.make_tensor(sub_tuple_index, mstype.int64)
return F.gather(data, tensor_index, 0)
tuple_index_new = ()
@ -341,15 +337,13 @@ def _tensor_getitem_by_tuple(data, tuple_index, op_name):
for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)):
if i in int_positions:
int_index = const_utils.check_and_transform_int_index(index, dim_size, op_name)
int_index = const_utils.check_range(index, dim_size)
tensor_index = F.scalar_to_tensor(int_index, mstype.int64)
tuple_index_new += (tensor_index,)
tensor_indexes.append(tensor_index)
tensor_positions += (i,)
elif i in sequence_positions:
sequence_index = const_utils.transform_sequence_index(index, dim_size, op_name)
tensor_index = const_utils.make_tensor(sequence_index)
tensor_index = F.cast(tensor_index, mstype.int64)
tensor_index = const_utils.sequence_to_index(index, dim_size)
tuple_index_new += (tensor_index,)
tensor_indexes.append(tensor_index)
tensor_positions += (i,)
@ -398,6 +392,8 @@ def _generate_indices_from_tuple_of_tensor(tuple_index, op_name):
const_utils.check_types_valid(indexes_types, mstype.int_type, op_name)
tensor_index_shape = hyper_map(F.shape, tuple_index)
broadcast_shape = const_utils.generate_broadcast_shape(tensor_index_shape, op_name)
if len(broadcast_shape) < 2:
broadcast_shape = (1,) + broadcast_shape
broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index)
new_broadcast_tensors = ()
for tensor in broadcast_tensors:
@ -417,15 +413,13 @@ def _generate_indices_from_tuple(data, tuple_index, op_name):
for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)):
if i in int_positions:
int_index = const_utils.check_and_transform_int_index(index, dim_size, op_name)
int_index = const_utils.check_range(index, dim_size)
tensor_index = F.scalar_to_tensor(int_index, mstype.int64)
tuple_index_new += (tensor_index,)
tensor_indexes.append(tensor_index)
tensor_positions += (i,)
elif i in sequence_positions:
sequence_index = const_utils.transform_sequence_index(index, dim_size, op_name)
tensor_index = const_utils.make_tensor(sequence_index)
tensor_index = F.cast(tensor_index, mstype.int64)
tensor_index = const_utils.sequence_to_index(index, dim_size)
tuple_index_new += (tensor_index,)
tensor_indexes.append(tensor_index)
tensor_positions += (i,)
@ -435,11 +429,9 @@ def _generate_indices_from_tuple(data, tuple_index, op_name):
tuple_index_new += (tensor_index,)
tensor_indexes.append(tensor_index)
elif i in slice_positions:
start, stop, _ = const_utils.slice_to_tuple(index)
start = const_utils.normalize_start(start, dim_size)
stop = const_utils.normalize_stop(stop, dim_size)
if start >= stop:
return None
start, stop, step = const_utils.normalize_slice(index, dim_size)
if const_utils.check_slice_empty(start, stop, step):
return False
slice_ele_list_index = const_utils.transform_slice_to_ele_list(index, dim_size)
slice_shapes += (len(slice_ele_list_index),)
tuple_index_new += (slice_ele_list_index,)
@ -522,6 +514,7 @@ def tensor_setitem_by_tensor(self, index, value):
def tensor_setitem_by_tuple(self, index, value):
if isinstance(value, (int, float, bool)):
index = format_tuple_indices(index)
return tensor_setitem_by_tuple_with_number(self, index, value)
if isinstance(value, Tensor):
return tensor_setitem_by_tuple_with_tensor(self, index, value)
@ -622,22 +615,6 @@ def tensor_setitem_by_tensor_with_sequence(data, index, value):
return _tensor_setitem_by_tensor_with_sequence(data, index, value)
def _tensor_indices_number(data, data_shape, index, indices, value):
"""Assigns a scalar value to the tensor."""
data_size = F.shape_mul(data.shape)
data_dtype = F.dtype(data)
indices_size = F.size(indices)
indices_size = const_utils.check_indices(indices_size, index)
update = F.fill(mstype.int32, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition = F.reshape(condition_1d, data_shape)
condition = F.cast(condition, mstype.bool_)
value_fill = F.fill(data_dtype, (indices_size,), value)
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
u = F.reshape(value_1d, data_shape)
return F.select(condition, u, data)
def _tensor_setitem_by_tensor_with_sequence(data, index, value):
"""Set a tensor item by a tensor with a tuple."""
updates = _generate_updates_from_sequence(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR)
@ -647,30 +624,20 @@ def _tensor_setitem_by_tensor_with_sequence(data, index, value):
def tensor_setitem_by_slice_with_number(data, input_slice, value):
"""Givens a scalar assign to tensor by slice"""
check_result = const_utils.check_tensor_setitem_index(input_slice)
result = None
if check_result:
data_shape = F.shape(data)
indices = const_utils.slice2indices(input_slice, data_shape)
if indices is False:
return data
is_tuple_int = const_utils.tuple_element_is_int(input_slice)
if is_tuple_int:
indices = const_utils.integer_to_indices(input_slice, data_shape)
result = _tensor_indices_number(data, data_shape, input_slice, indices, value)
return result
value = F.fill(F.dtype(data), const_utils.tuple_slice(F.shape(data), 1, None), value)
return tensor_setitem_by_slice_with_tensor(data, input_slice, value)
def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
"""Assigns the tensor by tuple with number value."""
tuple_index = ignore_dim_expand(tuple_index)
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, const_utils.TENSOR_GETITEM)
tuple_index, _ = remove_expanded_dims(tuple_index, F.shape(data))
if tuple_index is False:
return data
if len(tuple_index) == 1:
data[tuple_index[0]] = value
return data
op_name = const_utils.TENSOR_GETITEM
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
data, tuple_index = _expand_data_dims(data, tuple_index)
indexes_types = hyper_map(F.typeof, tuple_index)
contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM)
@ -682,37 +649,12 @@ def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
if int_cnt == const_utils.ALL_INT:
tuple_index = const_utils.convert_int_to_slice(tuple_index)
indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM)
if indices is None:
if indices is False:
return data
updates = _generate_updates_from_scalar(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
return P.TensorScatterUpdate()(data, indices, updates)
def _tensor_indices_tensor(data, data_shape, index, indices, value):
"""Assigns a tensor value to the tensor."""
data_size = F.shape_mul(data.shape)
data_dtype = F.dtype(data)
indices_size = F.size(indices)
indices_size = const_utils.check_indices(indices_size, index)
update = F.fill(mstype.int32, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition = F.reshape(condition_1d, data_shape)
condition = F.cast(condition, mstype.bool_)
value_fill = None
value_size = value.size
value_size = const_utils.check_indices_value_size(indices_size, value_size)
if value_size == 1:
value_fill = F.fill(data_dtype, (indices_size,), 1)
value = F.cast(value, data_dtype)
value_fill = F.tensor_mul(value_fill, value)
elif value_size > 1:
value_fill = F.reshape(value, (indices_size,))
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
u = F.reshape(value_1d, data_shape)
return F.select(condition, u.astype(data_dtype), data)
def tensor_setitem_by_slice_with_tensor(data, input_slice, value):
"""Assigns a tensor value to the tensor by slice."""
result = None
@ -722,10 +664,9 @@ def tensor_setitem_by_slice_with_tensor(data, input_slice, value):
indices = const_utils.slice2indices(input_slice, data_shape)
if indices is False:
return data
is_tuple_int = const_utils.tuple_element_is_int(input_slice)
if is_tuple_int:
indices = const_utils.integer_to_indices(input_slice, data_shape)
result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value)
value_shape = const_utils.tuple_slice(F.shape(indices), None, -1)
value = _broadcast(value_shape, value)
result = P.TensorScatterUpdate()(data, indices, value.astype(F.dtype(data)))
return result
@ -737,16 +678,17 @@ def tensor_setitem_by_slice_with_sequence(data, input_slice, value):
def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
"""Assigns the tensor by tuple with tensor value."""
value_shape = remove_ignored_dim(tuple_index, F.shape(value), F.rank(data))
op_name = const_utils.TENSOR_SETITEM
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
tuple_index, not_expanded_dim = remove_expanded_dims(tuple_index, F.shape(data))
if tuple_index is False:
return data
value_shape = const_utils.filter_expanded_dims(F.shape(value), not_expanded_dim)
value = F.reshape(value, value_shape)
tuple_index = ignore_dim_expand(tuple_index)
if len(tuple_index) == 1:
data[tuple_index[0]] = value
return data
op_name = const_utils.TENSOR_GETITEM
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
data, tuple_index = _expand_data_dims(data, tuple_index)
indexes_types = hyper_map(F.typeof, tuple_index)
contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM)
@ -763,7 +705,7 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
new_shape += value.shape
value = F.reshape(value, new_shape)
indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM)
if indices is None:
if indices is False:
return data
updates = _generate_updates_from_tensor(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
return P.TensorScatterUpdate()(data, indices, updates)
@ -776,9 +718,8 @@ def tensor_setitem_by_tuple_with_sequence(data, tuple_index, value):
def tensor_setitem_by_number_with_number(data, index, value):
"""Assigns the tensor by number with number value."""
data_shape = F.shape(data)
indices = const_utils.integer_to_indices(index, data_shape)
return _tensor_indices_number(data, data_shape, index, indices, value)
value = F.fill(F.dtype(data), const_utils.tuple_slice(F.shape(data), 1, None), value)
return tensor_setitem_by_number_with_tensor(data, index, value)
def tensor_setitem_by_number_with_sequence(data, index, value):
@ -790,8 +731,10 @@ def tensor_setitem_by_number_with_sequence(data, index, value):
def tensor_setitem_by_number_with_tensor(data, index, value):
"""Assigns the tensor by number with tensor value."""
data_shape = F.shape(data)
indices = const_utils.integer_to_indices(index, data_shape)
return _tensor_indices_tensor(data, data_shape, index, indices, value)
index = const_utils.int_to_index(index, data_shape)
value_shape = const_utils.tuple_slice(F.shape(index), None, -1)
value = _broadcast(value_shape, value)
return P.TensorScatterUpdate()(data, index, value)
def tensor_setitem_by_ellipsis_with_number(data, value):
@ -825,8 +768,10 @@ def tensor_setitem_by_bool(data, index, value):
data_shape = F.shape(data)
if not index:
data_shape = (0,) + data_shape
if not isinstance(value, Tensor):
if isinstance(value, (list, tuple)):
value = _generate_updates_from_sequence(data, index, value, const_utils.SET_ITEM_BY_NON_TENSOR)
elif isinstance(value, (int, float, bool)):
value = const_utils.make_tensor(value)
value_shape = F.shape(value)
source_shape = const_utils.get_source_shape(data_shape, value_shape)
if index:
@ -851,8 +796,7 @@ def format_list_indices(list_indices, length):
# If eyery element in list is bool, it's treated as 1-D bool tensor.
# If every element in list is int(not all bool), it's treated as int tensor.
if const_utils.judge_indexes_types(indices_types, mstype.int_type+(mstype.bool_,)):
list_indices = const_utils.transform_sequence_index(list_indices, length, const_utils.TENSOR_SETITEM)
return const_utils.make_tensor(list_indices)
return const_utils.sequence_to_index(list_indices, length)
# If list contains other types(.../list/tuple/None), it's treated as a tuple
return const_utils.deep_tuple(list_indices)
@ -871,64 +815,69 @@ def format_tuple_indices(tuple_indices):
return res
def tuple_indices_have_false(tuple_indices):
"""Returns True if tuple_indices contains False."""
for i in tuple_indices:
if i is False:
return True
return False
def remove_expanded_dims(tuple_index, data_shape):
"""Removes expanded dimensions in tuple_index and value."""
op_name = const_utils.TENSOR_SETITEM
not_expanded_dim = ()
shapes = ()
has_true = False
has_false = False
has_sequence = False
indices_out = () # with dimension expansion indices removed
idx_tensor = -1 # index of the previous tensor
idx_advanced = -1 # index of the first advanced index in expanded tensor
cur_dim = 0 # current dimension of the data to be indexed
for i, v in enumerate(tuple_index):
index_out = format_index(v, data_shape, cur_dim)
def ignore_dim_expand(idx):
"""Filters flags for dimension expansion from idx."""
res = ()
for i in idx:
if not i is True and not i is None:
res += (i,)
if not res:
res = (True,)
return res
def remove_ignored_dim(idx, value_shape, data_rank):
"""Removes dimensions in value that correspond to dimension expansion flags in index."""
has_ellipsis = False
has_leading_true = False
has_trailing_true = False
cnt_leading_expanded = 0
cnt_trailing_expanded = 0
cnt_not_dim_expand = 0
for i in idx:
if i is True:
if has_ellipsis:
has_trailing_true = True
else:
has_leading_true = True
elif i is None:
if has_ellipsis:
cnt_trailing_expanded += 1
else:
cnt_leading_expanded += 1
if index_out is None:
not_expanded_dim += (False,)
elif const_utils.is_slice(index_out):
indices_out += (index_out,)
not_expanded_dim += (True,)
start, stop, step = const_utils.normalize_slice(index_out, data_shape[cur_dim])
if const_utils.check_slice_empty(start, stop, step):
has_false = True
cur_dim += 1
elif isinstance(index_out, (Tensor, bool)): # advanced index
if idx_advanced == -1:
idx_advanced = len(not_expanded_dim)
elif i - idx_tensor > 1:
idx_advanced = 0
idx_tensor = i
if isinstance(index_out, Tensor):
if F.rank(index_out) > 0:
has_sequence = True
indices_out += (index_out,)
shapes += (F.shape(index_out),)
cur_dim += 1
has_true = has_true or index_out is True
has_false = has_false or index_out is False
else:
if const_utils.is_ellipsis(i):
has_ellipsis = True
cnt_not_dim_expand += 1
if cnt_not_dim_expand + 1 < data_rank:
if has_leading_true:
cnt_leading_expanded += 1
elif has_trailing_true:
cnt_trailing_expanded += 1
const_utils.raise_index_error('invalid index type')
value_starting_pos = 0
while cnt_leading_expanded > 0 and value_shape[value_starting_pos] == 1:
value_starting_pos += 1
cnt_leading_expanded -= 1
broadcast_shape = const_utils.generate_broadcast_shape(shapes, op_name)
if has_false:
if F.shape_mul(broadcast_shape) != 1:
const_utils.raise_index_error('unable to broadcast indices')
return False, not_expanded_dim
value_expanded_pos = len(value_shape) - cnt_trailing_expanded
value_expanded_not_unit = False
for i in const_utils.tuple_slice(value_shape, value_expanded_pos, None):
if i != 1:
value_expanded_not_unit = True
if value_expanded_pos < 0 or value_expanded_not_unit:
const_utils.raise_value_error('shape mismatch')
return const_utils.tuple_slice(value_shape, value_starting_pos, value_expanded_pos)
expand_true = has_true and not(has_false or has_sequence) # whether to expand dimension at True
tensor_index_ndim = len(broadcast_shape) # ndim of tensor indices
rem_ndim = len(data_shape) - cur_dim # number of remaining dimensions in data not indexed
not_expanded_dim = const_utils.rem_not_expanded_dims(idx_advanced, expand_true, tensor_index_ndim,
rem_ndim, not_expanded_dim)
if not indices_out:
indices_out = (True,)
return indices_out, not_expanded_dim
def format_index(idx, data_shape, cur_dim):
"""Converts advanced index into tensor."""
if isinstance(idx, (tuple, list)):
idx = const_utils.sequence_to_index(idx, data_shape[cur_dim])
elif isinstance(idx, int) and not isinstance(idx, bool):
idx = const_utils.make_tensor(idx, mstype.int64, None, data_shape[cur_dim])
return idx

View File

@ -14,7 +14,7 @@
# ============================================================================
"""constexpr util"""
from functools import reduce
from itertools import compress
import numpy as np
@ -75,10 +75,12 @@ def make_empty_slice():
@constexpr
def _deep_list(array_like):
def _deep_list(array_like, ndim=-1):
"""convert nested tuple/list mixtures to pure nested list"""
if ndim != -1:
check_range(array_like, ndim)
if isinstance(array_like, (list, tuple)):
return list(map(_deep_list, array_like))
return list(map(lambda x: _deep_list(x, ndim), array_like))
return array_like
@ -115,7 +117,16 @@ def _deep_tensor_to_nparray(array_like):
@constexpr
def make_tensor(a, dtype=mstype.int32, data_shape=None):
def check_range(x, ndim):
if isinstance(x, int) and not isinstance(x, bool):
if x >= ndim or x < -ndim:
raise IndexError(f'index {x} if out of bounds for dimension with size {ndim}')
x = x%ndim
return x
@constexpr
def make_tensor(a, dtype=mstype.int64, data_shape=None, ndim=-1):
"""
Converts the input to tensor.
@ -133,16 +144,18 @@ def make_tensor(a, dtype=mstype.int32, data_shape=None):
TypeError: If input arguments have types not specified above.
ValueError: If input `a` has different sizes at different dimensions.
"""
if data_shape:
return Tensor(np.zeros(data_shape), dtype)
if not isinstance(a, (list, tuple, int, float, bool)):
raise TypeError("input data must be `int`, `float`, `bool`, `list` or `tuple`")
if ndim != -1:
check_range(a, ndim)
if isinstance(a, (list, tuple)):
# Convert all tuple/nested tuples to lists
a = _deep_list(a)
a = _deep_list(a, ndim)
# Convert all tensor sub-elements to numpy arrays
a = _deep_tensor_to_nparray(a)
a = np.asarray(a)
@ -292,51 +305,6 @@ def get_pos_of_indexes_types(indexes_types, op_name):
tensor_positions, sequence_positions
def slice_expand(input_slices, shape):
"""
Converts slice to indices.
Inputs:
slices (Union[Slice, tuple[Slice]]): Slice tuple or slice.
shape (tuple): The shape of a sensor is an integer element tuple.
Outputs:
tuple[list], This is expressed as (begins, ends, strides).
"""
begin, end, strides = [], [], []
index = 0
slices = None
# Slice or tuple(Slice...)
if isinstance(input_slices, slice):
slices = (input_slices,)
elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], (slice, type(...))):
is_have_ellipsis = False
for _, element in enumerate(input_slices):
if isinstance(element, type(...)):
is_have_ellipsis = True
break
if is_have_ellipsis:
slices = ellipsis2slice(input_slices, shape)
else:
slices = input_slices
else:
raise IndexError("Tensor's index type is not supported yet.")
for s in slices:
start = 0 if (s.start is None) else s.start
stop = shape[index] if (s.stop is None) else s.stop
step = 1 if (s.step is None) else s.step
begin.append(start)
end.append(stop)
strides.append(step)
index += 1
while index < len(shape):
begin.append(0)
end.append(shape[index])
strides.append(1)
index += 1
return begin, end, strides
def ellipsis2slice(input_, shape):
"""Converts ellipsis to slice."""
input_slice = input_
@ -358,30 +326,24 @@ def ellipsis2slice(input_, shape):
@constexpr
def slice2indices(input_slices, shape):
def slice2indices(input_slice, shape):
"""
Converts slice to indices.
Inputs:
slices (Union[Slice, tuple[Slice]]): Slice tuple or slice.
input_slice (Union[Slice, tuple[Slice]]): Slice tuple or slice.
shape (tuple): The shape of a tensor is an integer element tuple.
Outputs:
Tensor, the shape is (n, 1).
"""
begin, end, strides = slice_expand(input_slices, shape)
np_r = []
for i, element in enumerate(shape):
s = normalize_start(begin[i], element)
e = normalize_stop(end[i], element)
if s >= e:
return False
np_r.append(np.r_[s:e:strides[i]])
# Reference: np.ravel_multi_index((np.ix_(np.r_[1:3:1], np.r_[0:4:1], np.r_[4:0:-1])), a.shape)
np_ix = np.ix_(*np_r)
ravel = np.ravel_multi_index(np_ix, shape)
ravel = Tensor(ravel.reshape(-1, 1), dtype=mstype.int32)
return ravel
start, stop, step = normalize_slice(input_slice, shape[0])
if check_slice_empty(start, stop, step):
return False
grids = ([np.array(list(range(start, stop, step)), dtype=np.int64)] +
[np.array(list(range(dim_size)), dtype=np.int64) for dim_size in shape[1:]])
mesh = np.ix_(*grids)
return Tensor(np.stack(np.broadcast_arrays(*mesh), axis=-1))
@constexpr
@ -406,29 +368,6 @@ def check_indices_value_size(indices_size, value_size):
return value_size
@constexpr
def integer_to_indices(index, shape):
"""Converts int or tuple[int] to indices."""
size = reduce(lambda x, y: x * y, shape)
range_ = np.arange(size).reshape(shape)
value = range_[index]
value = value.reshape(-1, 1)
return Tensor(value, dtype=mstype.int32)
@constexpr
def tuple_element_is_int(indexes):
"""Judges tuple element type."""
if not indexes:
raise IndexError("Tensor's index cannot be empty.")
if isinstance(indexes, tuple):
for _, ele in enumerate(indexes):
if not isinstance(ele, int):
return False
return True
return False
@constexpr
def tuple_index_int_cnt(types, op_name):
"""count the int type of types which contains the tuple elements' type."""
@ -439,12 +378,9 @@ def tuple_index_int_cnt(types, op_name):
@constexpr
def tuple_index_type_cnt(types, op_name):
"""count the tensor type of types which contains the tuple elements' type."""
tensor_cnt = sum(isinstance(ele, mstype.tensor_type) for ele in types)
basic_cnt = sum(isinstance(
ele, (mstype.Int, mstype.Ellipsis_, mstype.Slice)) for ele in types)
if tensor_cnt == len(types):
if all(isinstance(ele, mstype.tensor_type) for ele in types):
return ALL_TENSOR
if basic_cnt == len(types):
if all(isinstance(ele, (mstype.Int, mstype.Ellipsis_, mstype.Slice)) for ele in types):
return ALL_BASIC
return MIXED
@ -501,19 +437,10 @@ def generate_broadcast_shape(shapes, op_name):
@constexpr
def check_two_shapes_need_broadcast(shape_x, shape_y):
"""Check two shapes need broadcast."""
error = ValueError(f"For 'tensor setitem with tensor', the value tensor shape "
f"{shape_y} could not broadcast the required updates shape {shape_x}.")
if len(shape_y) > len(shape_x):
raise error
for i in range(-len(shape_y), 0):
if shape_y[i] > shape_x[i]:
raise error
if shape_y[i] < shape_x[i] and shape_y[i] != 1:
raise error
if shape_y == shape_x:
return False
return True
"""Check shape_y needs to be broadcast to shape_x."""
if any(j not in (i, 1) for i, j in zip(reversed(shape_x), reversed(shape_y))):
raise ValueError(f"{shape_y} could not broadcast with {shape_x}.")
return shape_y != shape_x
@constexpr
@ -523,53 +450,12 @@ def compute_multiples(origin_shape, broadcast_shape):
return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], origin_shape))
@constexpr
def compute_new_shape(origin_shape, indexes_shapes_info):
"""Compute new shape between origin shape with final shape."""
new_shape = []
for i in indexes_shapes_info:
if i == origin_shape:
new_shape.extend(origin_shape)
else:
new_shape.append(1)
return tuple(new_shape)
@constexpr
def convert_int_to_slice(tuple_index):
tuple_index_new = tuple(slice(i, i+1, 1) for i in tuple_index)
return tuple_index_new
@constexpr
def check_and_transform_int_index(index, shape, op_name):
if index < -shape or index >= shape:
raise IndexError(f"In the \"{op_name}\", the index should in the range [-{shape}, {shape-1}] to fit "
f"the corresponding dim length, but get {index}.")
if index < 0:
index += shape
return index
@constexpr
def transform_sequence_index(sequence_index, shape, op_name):
"""transform list or tuple with integer and boolean to tuple with integer index"""
bool_count = len(list(filter(lambda index: isinstance(index, bool), sequence_index)))
int_count = len(list(filter(lambda index: isinstance(index, int), sequence_index)))-bool_count
if int_count == 0 and bool_count != 0:
if bool_count == shape:
list_index = list(filter(lambda i: sequence_index[i], range(bool_count)))
else:
raise IndexError("The boolean array should have the same length with the corresponding dimension")
else:
list_index = [int(index) for index in sequence_index]
for i, index in enumerate(list_index):
list_index[i] = check_and_transform_int_index(index, shape, op_name)
sub_tuple_index = tuple(list_index)
return sub_tuple_index
@constexpr
def convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape, slice_shapes, fancy_position):
"""Convert a slice to a tensor."""
@ -585,16 +471,6 @@ def convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape, slic
return slice_index_tensor
@constexpr
def check_shapes_same(value_shapes, op_name):
"""Check if the shapes in the tuple are consistent."""
for i, shape in enumerate(value_shapes):
if shape != value_shapes[0]:
raise ValueError(f"For '{op_name}', the {i}th tensor shape in "
f"value tuple is not same as the first tensor shape.")
return True
@constexpr
def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type):
"""Convert a scalar to a tensor."""
@ -605,18 +481,6 @@ def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_ty
return Tensor(np.full(updates_shape, value), dtype=data_dtype)
@constexpr
def convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type):
"""Convert a tuple of scalar to a tensor."""
updates_shape = generate_updates_shape(data_shape, index_shape, op_type)
if len(value) != updates_shape[-1]:
raise ValueError(f"For '{TENSOR_SETITEM}', the number of elements : {len(value)} "
f"in the updates tuple does not meet the requirements: {updates_shape[-1]}.")
array = np.array(value, dtype=mstype.dtype_to_nptype(data_dtype))
reps = compute_multiples(updates_shape[-1:], updates_shape)
return Tensor(np.tile(array, reps))
@constexpr
def generate_updates_shape(data_shape, index_shape, op_type):
"""Generate updates shape for 'tensor setitem'."""
@ -679,9 +543,7 @@ def scalar_in_sequence(x, y):
if y is None:
raise ValueError("Judge scalar in tuple or list require scalar and sequence should be constant, "
"but the sequence is not.")
if x in y:
return True
return False
return x in y
@constexpr
@ -782,12 +644,6 @@ def get_stride_info_from_tuple(data_shape, tuple_index):
return tuple(begin_strides), tuple(end_strides), tuple(step_strides), shrink_axis
@constexpr
def mstype_eq(x, y):
"""Determine whether the input `x` equals `y`."""
return x == y
@constexpr
def scalar_to_tensor(x):
"""Convert a scalar to a tensor"""
@ -801,11 +657,6 @@ def unpack(x):
return x
@constexpr
def slice_to_tuple(s):
return (s.start, s.stop, s.step)
@constexpr
def normalize_start(start, dim_size):
"""
@ -833,8 +684,20 @@ def normalize_stop(stop, dim_size):
@constexpr
def is_ellipsis(x):
return x is Ellipsis
def normalize_slice(input_slice, dim_size):
"""Normalizes start, stop, step in a slice."""
start = normalize_start(input_slice.start, dim_size)
stop = normalize_stop(input_slice.stop, dim_size)
step = input_slice.step
if step is None:
step = 1
if step >= 0:
start = normalize_start(input_slice.start, dim_size)
stop = normalize_stop(input_slice.stop, dim_size)
else:
start = normalize_stop(input_slice.start, dim_size)
stop = normalize_start(input_slice.stop, dim_size)
return start, stop, step
@constexpr
@ -869,3 +732,64 @@ def sequence_mul_int(seq, number):
def check_in_sequence(x, y):
"""Determine whether the input `x` is in the sequence `y`."""
return x in y
@constexpr
def is_slice(x):
return isinstance(x, slice)
@constexpr
def filter_expanded_dims(shape, not_expanded_dim):
diff = len(not_expanded_dim) - len(shape)
if diff < 0:
raise ValueError('unable to broadcast {shape}')
return tuple(compress(shape, not_expanded_dim[diff:]))
@constexpr
def sequence_to_index(sequence, dim_size):
"""Transforms sequence to tensor index."""
if not sequence:
return False
if all(isinstance(i, bool) for i in sequence):
seq_size = len(sequence)
if seq_size != dim_size:
raise IndexError('dimension is {dim_size} but corresponding boolean dimension is {seq_size}')
sequence = tuple(compress(range(dim_size), sequence))
if not sequence:
return False
return make_tensor(sequence, mstype.int64, None, dim_size)
@constexpr
def int_to_index(i, shape):
"""Converts integer to tensor indices."""
dim_size = shape[0]
if i < -dim_size or i >= dim_size:
raise IndexError(f'index {i} is out of bounds for axis 0 with size {dim_size}')
i = i%dim_size
if len(shape) == 1:
return Tensor([[i]])
grids = [np.array(list(range(size)), dtype=np.int64) for size in shape[1:]]
mesh = np.ix_(*grids)
index = np.stack(np.broadcast_arrays(*mesh), -1)
return Tensor(np.insert(index, 0, i, -1))
@constexpr
def rem_not_expanded_dims(idx_advanced, expand_true, tensor_index_ndim, rem_ndim, not_expanded_dim):
"""Adds remaining dimensions not indexed to not_expanded_dim"""
if idx_advanced != -1:
if expand_true:
# tensor indices generate only one dimension with size 1
tensor_dims = (False,)
else:
tensor_dims = (True,)*tensor_index_ndim
not_expanded_dim = not_expanded_dim[:idx_advanced] + tensor_dims + not_expanded_dim[idx_advanced:]
return not_expanded_dim + (True,)*rem_ndim
@constexpr
def check_slice_empty(start, stop, step):
return (start - stop)*step >= 0

View File

@ -68,7 +68,7 @@ def _equal_mstype(x, y):
Returns:
bool, if x == y return true, x != y return false.
"""
return const_utils.mstype_eq(x, y)
return const_utils.is_same_type(x, y)
@equal.register("String", "String")

View File

@ -62,3 +62,17 @@ def _logical_not_tuple(x):
bool, Return logical not operation result of x.
"""
return F.bool_not(x.__bool__())
@logical_not.register("List")
def _logical_not_list(x):
"""
Return logical not operation result of a list object.
Args:
x(List): The input tuple.
Returns:
bool, Return logical not operation result of x.
"""
return F.bool_not(x.__bool__())

View File

@ -54,7 +54,7 @@ def _not_equal_mstype(x, y):
Returns:
bool, if x != y return true, x == y return false.
"""
return not const_utils.mstype_eq(x, y)
return not const_utils.is_same_type(x, y)
@not_equal.register("String", "String")

View File

@ -214,9 +214,6 @@ def _tensor_setitem_by_tuple_with_number(data, tuple_index, value):
Outputs:
Tensor, element type and shape is same as data.
"""
if compile_utils.tuple_indices_have_false(tuple_index):
return data
tuple_index = compile_utils.format_tuple_indices(tuple_index)
return compile_utils.tensor_setitem_by_tuple_with_number(data, tuple_index, value)
@ -238,9 +235,6 @@ def _tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
Outputs:
Tensor, element type and shape is same as data.
"""
if compile_utils.tuple_indices_have_false(tuple_index):
return data
tuple_index = compile_utils.format_tuple_indices(tuple_index)
return compile_utils.tensor_setitem_by_tuple_with_tensor(data, tuple_index, value)
@ -263,9 +257,6 @@ def _tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
Outputs:
Tensor, element type and shape is same as data.
"""
if compile_utils.tuple_indices_have_false(tuple_index):
return data
tuple_index = compile_utils.format_tuple_indices(tuple_index)
return compile_utils.tensor_setitem_by_tuple_with_sequence(data, tuple_index, value)
@ -288,9 +279,6 @@ def _tensor_setitem_by_tuple_with_list(data, tuple_index, value):
Outputs:
Tensor, element type and shape is same as data.
"""
if compile_utils.tuple_indices_have_false(tuple_index):
return data
tuple_index = compile_utils.format_tuple_indices(tuple_index)
return compile_utils.tensor_setitem_by_tuple_with_sequence(data, tuple_index, value)
@ -587,6 +575,86 @@ def _tensor_setitem_by_ellipsis_with_tuple(data, index, value):
return compile_utils.tensor_setitem_by_ellipsis_with_sequence(data, value)
@setitem.register("Tensor", "None", "Number")
def _tensor_setitem_by_none_with_number(data, index, value):
"""
Tensor assignment.
Note:
Syntax support: A[...] = u
Restraint condition: A is a Tensor.
u is a Number.
Inputs:
data (Tensor): Assigned tensor.
index (None): Index is ``...``.
value (Number): Assignment value.
Outputs:
Tensor, element type and shape is same as data.
"""
return compile_utils.tensor_setitem_by_ellipsis_with_number(data, value)
@setitem.register("Tensor", "None", "Tensor")
def _tensor_setitem_by_none_with_tensor(data, index, value):
"""
Tensor assignment.
Note:
Syntax support: A[...] = u
Restraint condition: A is a Tensor.
u is a Tensor.
Inputs:
data (Tensor): Assigned tensor.
index (None): Index is ``...``.
value (Tensor): Assignment value.
Outputs:
Tensor, element type and shape is same as data.
"""
return compile_utils.tensor_setitem_by_ellipsis_with_tensor(data, value)
@setitem.register("Tensor", "None", "List")
def _tensor_setitem_by_none_with_list(data, index, value):
"""
Tensor assignment.
Note:
Syntax support: A[...] = u
Restraint condition: A is a Tensor.
u is a List, with all elements equal in length.
Inputs:
data (Tensor): Assigned tensor.
index (None): Index is ``...``.
value (Number): Assignment value.
Outputs:
Tensor, element type and shape is same as data.
"""
return compile_utils.tensor_setitem_by_ellipsis_with_sequence(data, value)
@setitem.register("Tensor", "None", "Tuple")
def _tensor_setitem_by_none_with_tuple(data, index, value):
"""
Tensor assignment.
Note:
Syntax support: A[...] = u
Restraint condition: A is a Tensor.
u is a Tuple, with all elements equal in length.
Inputs:
data (Tensor): Assigned tensor.
index (None): Index is ``...``.
value (Number): Assignment value.
Outputs:
Tensor, element type and shape is same as data.
"""
return compile_utils.tensor_setitem_by_ellipsis_with_sequence(data, value)
@setitem.register("Tensor", "List", "Number")
def _tensor_setitem_by_list_with_number(data, index, value):
"""
@ -608,9 +676,6 @@ def _tensor_setitem_by_list_with_number(data, index, value):
index = compile_utils.format_list_indices(index, data.shape[0])
if isinstance(index, Tensor):
return compile_utils.tensor_setitem_by_tensor_with_number(data, index, value)
if compile_utils.tuple_indices_have_false(index):
return data
index = compile_utils.format_tuple_indices(index)
return compile_utils.tensor_setitem_by_tuple_with_number(data, index, value)
@ -635,9 +700,6 @@ def _tensor_setitem_by_list_with_tensor(data, index, value):
index = compile_utils.format_list_indices(index, data.shape[0])
if isinstance(index, Tensor):
return compile_utils.tensor_setitem_by_tensor_with_tensor(data, index, value)
if compile_utils.tuple_indices_have_false(index):
return data
index = compile_utils.format_tuple_indices(index)
return compile_utils.tensor_setitem_by_tuple_with_tensor(data, index, value)
@ -662,9 +724,6 @@ def _tensor_setitem_by_list_with_tuple(data, index, value):
index = compile_utils.format_list_indices(index, data.shape[0])
if isinstance(index, Tensor):
return compile_utils.tensor_setitem_by_tensor_with_sequence(data, index, value)
if compile_utils.tuple_indices_have_false(index):
return data
index = compile_utils.format_tuple_indices(index)
return compile_utils.tensor_setitem_by_tuple_with_sequence(data, index, value)
@ -689,7 +748,4 @@ def _tensor_setitem_by_list_with_list(data, index, value):
index = compile_utils.format_list_indices(index, data.shape[0])
if isinstance(index, Tensor):
return compile_utils.tensor_setitem_by_tensor_with_sequence(data, index, value)
if compile_utils.tuple_indices_have_false(index):
return data
index = compile_utils.format_tuple_indices(index)
return compile_utils.tensor_setitem_by_tuple_with_sequence(data, index, value)

View File

@ -772,8 +772,11 @@ def test_tensor_assign_slice_value_2():
def test_tensor_assign_exception():
net = TensorAssignWithSlice()
net2 = TensorAssignWithSlice2()
net_e1 = TensorAssignWithSliceError1()
net_e2 = TensorAssignWithSliceError2()
# The test case is no longer appropriate since x[1:3:-1] = np.array(2) does
# not incur an error in numpy, which leaves the original array unchanged after
# the assign operation.
# net_e1 = TensorAssignWithSliceError1()
# net_e2 = TensorAssignWithSliceError2()
a = np.arange(60).reshape(3, 4, 5)
ck = np.arange(60).reshape(3, 4, 5)
b = Tensor([1], dtype=mstype.float32)
@ -787,8 +790,8 @@ def test_tensor_assign_exception():
tck = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
# Error for A[Slice] = Number
# 1. A[Slice] = Number, Slice error
with pytest.raises(ValueError):
net_e2(t, 2)
# with pytest.raises(ValueError):
# net_e2(t, 2)
# Error for A[Slice] = U, U is a Tensor
# 1. A[Slice] = U, u.size is error
@ -809,13 +812,13 @@ def test_tensor_assign_exception():
with pytest.raises(ValueError):
net(Ta, Tb, Tck)
# 3. A[Tuple(Slice...)] = U, Slice error
with pytest.raises(IndexError):
net_e1(Ta, b)
# with pytest.raises(IndexError):
# net_e1(Ta, b)
# Error for A[Tuple(Slice...)] = Number
# 1. A[Tuple(Slice...)] = Number, Slice error
with pytest.raises(IndexError):
net_e1(Ta, 2)
# with pytest.raises(IndexError):
# net_e1(Ta, 2)
net = TensorAssignWithInteger()
# Error for A[Number] = scalar/Tensor

View File

@ -195,6 +195,7 @@ def test_setitem_by_slice():
x[5:0:3] = 5
x[5:5:5] = 6
x[-1:2] = 7
x[1:0:-1] = 8
return x
setup_testcase(x, cases)
@ -214,3 +215,19 @@ def test_setitem_by_tuple_of_slices():
x[1:1, 2:2] = 6
return x
setup_testcase(x, cases)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_dim_expand():
x = onp.ones((2, 3, 4), dtype=onp.float32)
def cases(x):
x[None, True, [1, 0], (False, True, True), [2]] = 2
x[([[0]]), ..., [[1]]] = [[[3, 3, 3]]]
x[0:1] = [[2, 3, 4, 5]]
x[..., (0, 1, 2), None, :, True, None] = [[[3], [3], [3], [3]]]
return x
setup_testcase(x, cases)

View File

@ -448,8 +448,11 @@ def test_tensor_assign():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
net = TensorAssignWithSlice()
net2 = TensorAssignWithSlice2()
net_e1 = TensorAssignWithSliceError1()
net_e2 = TensorAssignWithSliceError2()
# The test case is no longer appropriate since x[1:3:-1] = np.array(2) does
# not incur an error in numpy, which leaves the original array unchanged after
# the assign operation.
# net_e1 = TensorAssignWithSliceError1()
# net_e2 = TensorAssignWithSliceError2()
a = np.arange(60).reshape(3, 4, 5)
ck = np.arange(60).reshape(3, 4, 5)
b = Tensor([1], dtype=mstype.float32)
@ -465,8 +468,9 @@ def test_tensor_assign():
net2(t, b, tck)
# Error for A[Slice] = Number
# 1. A[Slice] = Number, 0 in shape
with pytest.raises(ValueError):
net_e2(t, Tensor(2, mstype.int32))
# with pytest.raises(ValueError):
# net_e2(t, Tensor(2, mstype.int32))
# Error for A[Slice] = U, U is a Tensor
# 1. A[Slice] = U, u.size is error
@ -487,13 +491,13 @@ def test_tensor_assign():
with pytest.raises(ValueError):
net(Ta, Tb, Tck)
# 3. A[Tuple(Slice...)] = U, Slice error
with pytest.raises(IndexError):
net_e1(Ta, b)
# with pytest.raises(IndexError):
# net_e1(Ta, b)
# Error for A[Tuple(Slice...)] = Number
# 1. A[Tuple(Slice...)] = Number, Slice error
with pytest.raises(IndexError):
net_e1(Ta, Tensor(2, mstype.int32))
# with pytest.raises(IndexError):
# net_e1(Ta, Tensor(2, mstype.int32))
net = TensorAssignWithInteger()
# Error for A[Number] = scalar/Tensor