forked from mindspore-Ecosystem/mindspore

fix & restructure setitem

This commit is contained in:
parent 1c44e367e0 · commit d6be043cb9
@@ -58,9 +58,6 @@ def _tensor_setitem(self, index, value):
     if isinstance(index, Tensor):
         return tensor_setitem_by_tensor(self, index, value)
     if isinstance(index, tuple):
-        if tuple_indices_have_false(index):
-            return self
-        index = format_tuple_indices(index)
         return tensor_setitem_by_tuple(self, index, value)
     if isinstance(index, bool):
         return tensor_setitem_by_bool(self, index, value)
@@ -68,7 +65,7 @@ def _tensor_setitem(self, index, value):
         return tensor_setitem_by_number(self, index, value)
     if isinstance(index, slice):
         return tensor_setitem_by_slice(self, index, value)
-    if index is ...:
+    if index in (None, ...):
         return tensor_setitem_by_ellipsis(self, index, value)
     raise IndexError("Tensor setitem index only support integers, slices(`:`), ellipsis(`...`), bool, tensor, \
@@ -142,7 +139,7 @@ tensor_operator_registry.register('__floordiv__', _tensor_floordiv)

 def _broadcast(broadcast_shape, x):
     """Broadcast tensor to the required shape."""
-    if F.shape(x) == broadcast_shape:
+    if not const_utils.check_two_shapes_need_broadcast(broadcast_shape, F.shape(x)):
         return x
     multiples = const_utils.compute_multiples(F.shape(x), broadcast_shape)
     if multiples:
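For intuition, the tiling `_broadcast` relies on mirrors the `compute_multiples` helper shown later in this diff; a minimal NumPy sketch (illustrative only, not the MindSpore API):

import numpy as np

def compute_multiples(origin_shape, broadcast_shape):
    # Per-dimension repetition factors, with origin_shape right-aligned
    # against broadcast_shape as in NumPy broadcasting.
    len_gap = len(broadcast_shape) - len(origin_shape)
    return broadcast_shape[:len_gap] + tuple(
        b // o for b, o in zip(broadcast_shape[len_gap:], origin_shape))

x = np.ones((1, 3))
reps = compute_multiples(x.shape, (2, 4, 3))
print(reps)                    # (2, 4, 1)
print(np.tile(x, reps).shape)  # (2, 4, 3)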
@@ -242,7 +239,7 @@ def _tensor_index_by_integer(data, int_index):
     const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)

     data_shape = F.shape(data)
-    transformed_number = const_utils.check_and_transform_int_index(int_index, data_shape[0], const_utils.TENSOR_GETITEM)
+    transformed_number = const_utils.check_range(int_index, data_shape[0])
     begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_integer(data_shape, transformed_number)
     shrink_axis_mask = 1
     return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)
@@ -264,10 +261,9 @@ def tensor_index_by_list(data, list_index):
     data_shape = F.shape(data)
     indexes_types = hyper_map(F.typeof, list_index)
     if const_utils.judge_indexes_types(indexes_types, mstype.int_type + (mstype.bool_,)):
-        sub_tuple_index = const_utils.transform_sequence_index(list_index, data_shape[0], const_utils.TENSOR_GETITEM)
-        if not sub_tuple_index:
+        tensor_index = const_utils.sequence_to_index(list_index, data_shape[0])
+        if tensor_index is False:
             const_utils.raise_index_error("Getitem does not support empty list, this will reference shape '0'.")
-        tensor_index = const_utils.make_tensor(sub_tuple_index, mstype.int64)
         return F.gather(data, tensor_index, 0)

     tuple_index_new = ()
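The list-index fast path above matches NumPy's gather semantics along axis 0; a quick illustration in plain NumPy:

import numpy as np

data = np.arange(12).reshape(4, 3)
print(data[[2, 0]])                       # rows 2 and 0, via integer gather
print(data[[True, False, True, False]])   # rows 0 and 2, via a full-length bool mask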
@@ -341,15 +337,13 @@ def _tensor_getitem_by_tuple(data, tuple_index, op_name):

     for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)):
         if i in int_positions:
-            int_index = const_utils.check_and_transform_int_index(index, dim_size, op_name)
+            int_index = const_utils.check_range(index, dim_size)
             tensor_index = F.scalar_to_tensor(int_index, mstype.int64)
             tuple_index_new += (tensor_index,)
             tensor_indexes.append(tensor_index)
             tensor_positions += (i,)
         elif i in sequence_positions:
-            sequence_index = const_utils.transform_sequence_index(index, dim_size, op_name)
-            tensor_index = const_utils.make_tensor(sequence_index)
-            tensor_index = F.cast(tensor_index, mstype.int64)
+            tensor_index = const_utils.sequence_to_index(index, dim_size)
             tuple_index_new += (tensor_index,)
             tensor_indexes.append(tensor_index)
             tensor_positions += (i,)
@@ -398,6 +392,8 @@ def _generate_indices_from_tuple_of_tensor(tuple_index, op_name):
     const_utils.check_types_valid(indexes_types, mstype.int_type, op_name)
     tensor_index_shape = hyper_map(F.shape, tuple_index)
     broadcast_shape = const_utils.generate_broadcast_shape(tensor_index_shape, op_name)
+    if len(broadcast_shape) < 2:
+        broadcast_shape = (1,) + broadcast_shape
     broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index)
     new_broadcast_tensors = ()
     for tensor in broadcast_tensors:
@@ -417,15 +413,13 @@ def _generate_indices_from_tuple(data, tuple_index, op_name):

     for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)):
         if i in int_positions:
-            int_index = const_utils.check_and_transform_int_index(index, dim_size, op_name)
+            int_index = const_utils.check_range(index, dim_size)
             tensor_index = F.scalar_to_tensor(int_index, mstype.int64)
             tuple_index_new += (tensor_index,)
             tensor_indexes.append(tensor_index)
             tensor_positions += (i,)
         elif i in sequence_positions:
-            sequence_index = const_utils.transform_sequence_index(index, dim_size, op_name)
-            tensor_index = const_utils.make_tensor(sequence_index)
-            tensor_index = F.cast(tensor_index, mstype.int64)
+            tensor_index = const_utils.sequence_to_index(index, dim_size)
             tuple_index_new += (tensor_index,)
             tensor_indexes.append(tensor_index)
             tensor_positions += (i,)
@@ -435,11 +429,9 @@ def _generate_indices_from_tuple(data, tuple_index, op_name):
             tuple_index_new += (tensor_index,)
             tensor_indexes.append(tensor_index)
         elif i in slice_positions:
-            start, stop, _ = const_utils.slice_to_tuple(index)
-            start = const_utils.normalize_start(start, dim_size)
-            stop = const_utils.normalize_stop(stop, dim_size)
-            if start >= stop:
-                return None
+            start, stop, step = const_utils.normalize_slice(index, dim_size)
+            if const_utils.check_slice_empty(start, stop, step):
+                return False
             slice_ele_list_index = const_utils.transform_slice_to_ele_list(index, dim_size)
             slice_shapes += (len(slice_ele_list_index),)
             tuple_index_new += (slice_ele_list_index,)
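The new emptiness test matches the `check_slice_empty` helper added later in this diff; in plain Python:

def check_slice_empty(start, stop, step):
    # A slice selects nothing when stepping from start can never reach stop.
    return (start - stop) * step >= 0

assert check_slice_empty(5, 5, 5)       # x[5:5:5]
assert check_slice_empty(3, 0, 1)       # x[3:0] with a positive step
assert not check_slice_empty(3, 0, -1)  # x[3:0:-1] selects 3, 2, 1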
@@ -522,6 +514,7 @@ def tensor_setitem_by_tensor(self, index, value):

 def tensor_setitem_by_tuple(self, index, value):
     if isinstance(value, (int, float, bool)):
+        index = format_tuple_indices(index)
         return tensor_setitem_by_tuple_with_number(self, index, value)
     if isinstance(value, Tensor):
         return tensor_setitem_by_tuple_with_tensor(self, index, value)
@@ -622,22 +615,6 @@ def tensor_setitem_by_tensor_with_sequence(data, index, value):
     return _tensor_setitem_by_tensor_with_sequence(data, index, value)


-def _tensor_indices_number(data, data_shape, index, indices, value):
-    """Assigns a scalar value to the tensor."""
-    data_size = F.shape_mul(data.shape)
-    data_dtype = F.dtype(data)
-    indices_size = F.size(indices)
-    indices_size = const_utils.check_indices(indices_size, index)
-    update = F.fill(mstype.int32, (indices_size,), 1)
-    condition_1d = F.scatter_nd(indices, update, (data_size,))
-    condition = F.reshape(condition_1d, data_shape)
-    condition = F.cast(condition, mstype.bool_)
-    value_fill = F.fill(data_dtype, (indices_size,), value)
-    value_1d = F.scatter_nd(indices, value_fill, (data_size,))
-    u = F.reshape(value_1d, data_shape)
-    return F.select(condition, u, data)
-
-
 def _tensor_setitem_by_tensor_with_sequence(data, index, value):
     """Set a tensor item by a tensor with a tuple."""
     updates = _generate_updates_from_sequence(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR)
@@ -647,30 +624,20 @@ def _tensor_setitem_by_tensor_with_sequence(data, index, value):

 def tensor_setitem_by_slice_with_number(data, input_slice, value):
     """Givens a scalar assign to tensor by slice"""
-    check_result = const_utils.check_tensor_setitem_index(input_slice)
-    result = None
-    if check_result:
-        data_shape = F.shape(data)
-        indices = const_utils.slice2indices(input_slice, data_shape)
-        if indices is False:
-            return data
-        is_tuple_int = const_utils.tuple_element_is_int(input_slice)
-        if is_tuple_int:
-            indices = const_utils.integer_to_indices(input_slice, data_shape)
-        result = _tensor_indices_number(data, data_shape, input_slice, indices, value)
-    return result
+    value = F.fill(F.dtype(data), const_utils.tuple_slice(F.shape(data), 1, None), value)
+    return tensor_setitem_by_slice_with_tensor(data, input_slice, value)


 def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
     """Assigns the tensor by tuple with number value."""
-    tuple_index = ignore_dim_expand(tuple_index)
+    tuple_index = _transform_ellipsis_to_slice(data, tuple_index, const_utils.TENSOR_GETITEM)
+    tuple_index, _ = remove_expanded_dims(tuple_index, F.shape(data))
+    if tuple_index is False:
+        return data
     if len(tuple_index) == 1:
         data[tuple_index[0]] = value
         return data
-    op_name = const_utils.TENSOR_GETITEM
-    tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
-    data, tuple_index = _expand_data_dims(data, tuple_index)

     indexes_types = hyper_map(F.typeof, tuple_index)
     contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM)
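The restructured scalar path simply materializes the scalar as one "row" of the data and reuses the tensor path; the same idea in NumPy (illustrative only):

import numpy as np

data = np.zeros((3, 4, 5), dtype=np.float32)
# F.fill(dtype, data.shape[1:], 7) corresponds to one scalar-filled row...
row = np.full(data.shape[1:], 7, dtype=data.dtype)
# ...which then broadcasts against the sliced region exactly like a tensor value.
data[0:2] = row
assert (data[0:2] == 7).all() and (data[2] == 0).all()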
@@ -682,37 +649,12 @@ def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
     if int_cnt == const_utils.ALL_INT:
         tuple_index = const_utils.convert_int_to_slice(tuple_index)
     indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM)
-    if indices is None:
+    if indices is False:
         return data
     updates = _generate_updates_from_scalar(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
     return P.TensorScatterUpdate()(data, indices, updates)


-def _tensor_indices_tensor(data, data_shape, index, indices, value):
-    """Assigns a tensor value to the tensor."""
-    data_size = F.shape_mul(data.shape)
-    data_dtype = F.dtype(data)
-    indices_size = F.size(indices)
-    indices_size = const_utils.check_indices(indices_size, index)
-    update = F.fill(mstype.int32, (indices_size,), 1)
-    condition_1d = F.scatter_nd(indices, update, (data_size,))
-    condition = F.reshape(condition_1d, data_shape)
-    condition = F.cast(condition, mstype.bool_)
-    value_fill = None
-    value_size = value.size
-
-    value_size = const_utils.check_indices_value_size(indices_size, value_size)
-    if value_size == 1:
-        value_fill = F.fill(data_dtype, (indices_size,), 1)
-        value = F.cast(value, data_dtype)
-        value_fill = F.tensor_mul(value_fill, value)
-    elif value_size > 1:
-        value_fill = F.reshape(value, (indices_size,))
-    value_1d = F.scatter_nd(indices, value_fill, (data_size,))
-    u = F.reshape(value_1d, data_shape)
-    return F.select(condition, u.astype(data_dtype), data)
-
-
 def tensor_setitem_by_slice_with_tensor(data, input_slice, value):
     """Assigns a tensor value to the tensor by slice."""
     result = None
@@ -722,10 +664,9 @@ def tensor_setitem_by_slice_with_tensor(data, input_slice, value):
         indices = const_utils.slice2indices(input_slice, data_shape)
         if indices is False:
             return data
-        is_tuple_int = const_utils.tuple_element_is_int(input_slice)
-        if is_tuple_int:
-            indices = const_utils.integer_to_indices(input_slice, data_shape)
-        result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value)
+        value_shape = const_utils.tuple_slice(F.shape(indices), None, -1)
+        value = _broadcast(value_shape, value)
+        result = P.TensorScatterUpdate()(data, indices, value.astype(F.dtype(data)))
     return result
@@ -737,16 +678,17 @@ def tensor_setitem_by_slice_with_sequence(data, input_slice, value):

 def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
     """Assigns the tensor by tuple with tensor value."""
-    value_shape = remove_ignored_dim(tuple_index, F.shape(value), F.rank(data))
+    op_name = const_utils.TENSOR_SETITEM
+    tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
+    tuple_index, not_expanded_dim = remove_expanded_dims(tuple_index, F.shape(data))
+    if tuple_index is False:
+        return data
+    value_shape = const_utils.filter_expanded_dims(F.shape(value), not_expanded_dim)
     value = F.reshape(value, value_shape)
-    tuple_index = ignore_dim_expand(tuple_index)

     if len(tuple_index) == 1:
         data[tuple_index[0]] = value
         return data
-    op_name = const_utils.TENSOR_GETITEM
-    tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
-    data, tuple_index = _expand_data_dims(data, tuple_index)

     indexes_types = hyper_map(F.typeof, tuple_index)
     contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM)
@@ -763,7 +705,7 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
             new_shape += value.shape
         value = F.reshape(value, new_shape)
     indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM)
-    if indices is None:
+    if indices is False:
         return data
     updates = _generate_updates_from_tensor(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
     return P.TensorScatterUpdate()(data, indices, updates)
@@ -776,9 +718,8 @@ def tensor_setitem_by_tuple_with_sequence(data, tuple_index, value):

 def tensor_setitem_by_number_with_number(data, index, value):
     """Assigns the tensor by number with number value."""
-    data_shape = F.shape(data)
-    indices = const_utils.integer_to_indices(index, data_shape)
-    return _tensor_indices_number(data, data_shape, index, indices, value)
+    value = F.fill(F.dtype(data), const_utils.tuple_slice(F.shape(data), 1, None), value)
+    return tensor_setitem_by_number_with_tensor(data, index, value)


 def tensor_setitem_by_number_with_sequence(data, index, value):
@@ -790,8 +731,10 @@ def tensor_setitem_by_number_with_sequence(data, index, value):

 def tensor_setitem_by_number_with_tensor(data, index, value):
     """Assigns the tensor by number with tensor value."""
     data_shape = F.shape(data)
-    indices = const_utils.integer_to_indices(index, data_shape)
-    return _tensor_indices_tensor(data, data_shape, index, indices, value)
+    index = const_utils.int_to_index(index, data_shape)
+    value_shape = const_utils.tuple_slice(F.shape(index), None, -1)
+    value = _broadcast(value_shape, value)
+    return P.TensorScatterUpdate()(data, index, value)


 def tensor_setitem_by_ellipsis_with_number(data, value):
@@ -825,8 +768,10 @@ def tensor_setitem_by_bool(data, index, value):
     data_shape = F.shape(data)
     if not index:
         data_shape = (0,) + data_shape
-    if not isinstance(value, Tensor):
+    if isinstance(value, (list, tuple)):
+        value = _generate_updates_from_sequence(data, index, value, const_utils.SET_ITEM_BY_NON_TENSOR)
+    elif isinstance(value, (int, float, bool)):
         value = const_utils.make_tensor(value)
     value_shape = F.shape(value)
     source_shape = const_utils.get_source_shape(data_shape, value_shape)
     if index:
@@ -851,8 +796,7 @@ def format_list_indices(list_indices, length):
     # If every element in list is bool, it's treated as a 1-D bool tensor.
     # If every element in list is int (not all bool), it's treated as an int tensor.
     if const_utils.judge_indexes_types(indices_types, mstype.int_type + (mstype.bool_,)):
-        list_indices = const_utils.transform_sequence_index(list_indices, length, const_utils.TENSOR_SETITEM)
-        return const_utils.make_tensor(list_indices)
+        return const_utils.sequence_to_index(list_indices, length)
     # If list contains other types (.../list/tuple/None), it's treated as a tuple
     return const_utils.deep_tuple(list_indices)
@@ -871,64 +815,69 @@ def format_tuple_indices(tuple_indices):
     return res


-def tuple_indices_have_false(tuple_indices):
-    """Returns True if tuple_indices contains False."""
-    for i in tuple_indices:
-        if i is False:
-            return True
-    return False
-
-
-def ignore_dim_expand(idx):
-    """Filters flags for dimension expansion from idx."""
-    res = ()
-    for i in idx:
-        if not i is True and not i is None:
-            res += (i,)
-    if not res:
-        res = (True,)
-    return res
-
-
-def remove_ignored_dim(idx, value_shape, data_rank):
-    """Removes dimensions in value that correspond to dimension expansion flags in index."""
-    has_ellipsis = False
-    has_leading_true = False
-    has_trailing_true = False
-    cnt_leading_expanded = 0
-    cnt_trailing_expanded = 0
-    cnt_not_dim_expand = 0
-    for i in idx:
-        if i is True:
-            if has_ellipsis:
-                has_trailing_true = True
-            else:
-                has_leading_true = True
-        elif i is None:
-            if has_ellipsis:
-                cnt_trailing_expanded += 1
-            else:
-                cnt_leading_expanded += 1
-        else:
-            if const_utils.is_ellipsis(i):
-                has_ellipsis = True
-            cnt_not_dim_expand += 1
-    if cnt_not_dim_expand + 1 < data_rank:
-        if has_leading_true:
-            cnt_leading_expanded += 1
-        elif has_trailing_true:
-            cnt_trailing_expanded += 1
-
-    value_starting_pos = 0
-    while cnt_leading_expanded > 0 and value_shape[value_starting_pos] == 1:
-        value_starting_pos += 1
-        cnt_leading_expanded -= 1
-
-    value_expanded_pos = len(value_shape) - cnt_trailing_expanded
-    value_expanded_not_unit = False
-    for i in const_utils.tuple_slice(value_shape, value_expanded_pos, None):
-        if i != 1:
-            value_expanded_not_unit = True
-    if value_expanded_pos < 0 or value_expanded_not_unit:
-        const_utils.raise_value_error('shape mismatch')
-    return const_utils.tuple_slice(value_shape, value_starting_pos, value_expanded_pos)
+def remove_expanded_dims(tuple_index, data_shape):
+    """Removes expanded dimensions in tuple_index and value."""
+    op_name = const_utils.TENSOR_SETITEM
+    not_expanded_dim = ()
+    shapes = ()
+    has_true = False
+    has_false = False
+    has_sequence = False
+    indices_out = ()  # with dimension expansion indices removed
+    idx_tensor = -1  # index of the previous tensor
+    idx_advanced = -1  # index of the first advanced index in expanded tensor
+    cur_dim = 0  # current dimension of the data to be indexed
+
+    for i, v in enumerate(tuple_index):
+        index_out = format_index(v, data_shape, cur_dim)
+        if index_out is None:
+            not_expanded_dim += (False,)
+        elif const_utils.is_slice(index_out):
+            indices_out += (index_out,)
+            not_expanded_dim += (True,)
+            start, stop, step = const_utils.normalize_slice(index_out, data_shape[cur_dim])
+            if const_utils.check_slice_empty(start, stop, step):
+                has_false = True
+            cur_dim += 1
+        elif isinstance(index_out, (Tensor, bool)):  # advanced index
+            if idx_advanced == -1:
+                idx_advanced = len(not_expanded_dim)
+            elif i - idx_tensor > 1:
+                idx_advanced = 0
+            idx_tensor = i
+            if isinstance(index_out, Tensor):
+                if F.rank(index_out) > 0:
+                    has_sequence = True
+                indices_out += (index_out,)
+                shapes += (F.shape(index_out),)
+                cur_dim += 1
+            has_true = has_true or index_out is True
+            has_false = has_false or index_out is False
+        else:
+            const_utils.raise_index_error('invalid index type')
+
+    broadcast_shape = const_utils.generate_broadcast_shape(shapes, op_name)
+    if has_false:
+        if F.shape_mul(broadcast_shape) != 1:
+            const_utils.raise_index_error('unable to broadcast indices')
+        return False, not_expanded_dim
+
+    expand_true = has_true and not (has_false or has_sequence)  # whether to expand dimension at True
+    tensor_index_ndim = len(broadcast_shape)  # ndim of tensor indices
+    rem_ndim = len(data_shape) - cur_dim  # number of remaining dimensions in data not indexed
+    not_expanded_dim = const_utils.rem_not_expanded_dims(idx_advanced, expand_true, tensor_index_ndim,
+                                                         rem_ndim, not_expanded_dim)
+
+    if not indices_out:
+        indices_out = (True,)
+    return indices_out, not_expanded_dim
+
+
+def format_index(idx, data_shape, cur_dim):
+    """Converts advanced index into tensor."""
+    if isinstance(idx, (tuple, list)):
+        idx = const_utils.sequence_to_index(idx, data_shape[cur_dim])
+    elif isinstance(idx, int) and not isinstance(idx, bool):
+        idx = const_utils.make_tensor(idx, mstype.int64, None, data_shape[cur_dim])
+    return idx
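The None/True bookkeeping in `remove_expanded_dims` follows NumPy's dimension-expansion rules, which this sketch illustrates (plain NumPy, not the implementation above):

import numpy as np

x = np.ones((2, 3))
print(x[None].shape)   # (1, 2, 3): None (np.newaxis) inserts a size-1 axis
print(x[True].shape)   # (1, 2, 3): scalar True likewise adds a size-1 axis
# A setitem through such an index still writes the original elements,
# so the expanded axes can be stripped from the index and the value alike:
x[None, 0:1] = 5
print(x[0])            # [5. 5. 5.]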
@@ -14,7 +14,7 @@
 # ============================================================================
 """constexpr util"""

-from functools import reduce
+from itertools import compress

 import numpy as np
@@ -75,10 +75,12 @@ def make_empty_slice():


 @constexpr
-def _deep_list(array_like):
+def _deep_list(array_like, ndim=-1):
     """convert nested tuple/list mixtures to pure nested list"""
+    if ndim != -1:
+        check_range(array_like, ndim)
     if isinstance(array_like, (list, tuple)):
-        return list(map(_deep_list, array_like))
+        return list(map(lambda x: _deep_list(x, ndim), array_like))
     return array_like
@@ -115,7 +117,16 @@ def _deep_tensor_to_nparray(array_like):


 @constexpr
-def make_tensor(a, dtype=mstype.int32, data_shape=None):
+def check_range(x, ndim):
+    if isinstance(x, int) and not isinstance(x, bool):
+        if x >= ndim or x < -ndim:
+            raise IndexError(f'index {x} is out of bounds for dimension with size {ndim}')
+        x = x % ndim
+    return x
+
+
+@constexpr
+def make_tensor(a, dtype=mstype.int64, data_shape=None, ndim=-1):
     """
     Converts the input to tensor.
@@ -133,16 +144,18 @@ def make_tensor(a, dtype=mstype.int32, data_shape=None):
         TypeError: If input arguments have types not specified above.
         ValueError: If input `a` has different sizes at different dimensions.
     """
     if data_shape:
         return Tensor(np.zeros(data_shape), dtype)

     if not isinstance(a, (list, tuple, int, float, bool)):
         raise TypeError("input data must be `int`, `float`, `bool`, `list` or `tuple`")

+    if ndim != -1:
+        check_range(a, ndim)
+
     if isinstance(a, (list, tuple)):
         # Convert all tuple/nested tuples to lists
-        a = _deep_list(a)
+        a = _deep_list(a, ndim)
         # Convert all tensor sub-elements to numpy arrays
         a = _deep_tensor_to_nparray(a)
         a = np.asarray(a)
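A plain-Python sketch of the bounds handling `check_range` applies to integer indices (here `dim_size` stands for the length of the dimension being indexed):

def check_range(x, dim_size):
    if isinstance(x, int) and not isinstance(x, bool):
        if x >= dim_size or x < -dim_size:
            raise IndexError(f'index {x} is out of bounds for dimension with size {dim_size}')
        x = x % dim_size  # negative indices wrap around, as in Python sequences
    return x

assert check_range(3, 5) == 3
assert check_range(-1, 5) == 4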
@@ -292,51 +305,6 @@ def get_pos_of_indexes_types(indexes_types, op_name):
     tensor_positions, sequence_positions


-def slice_expand(input_slices, shape):
-    """
-    Converts slice to indices.
-
-    Inputs:
-        slices (Union[Slice, tuple[Slice]]): Slice tuple or slice.
-        shape (tuple): The shape of a sensor is an integer element tuple.
-
-    Outputs:
-        tuple[list], This is expressed as (begins, ends, strides).
-    """
-    begin, end, strides = [], [], []
-    index = 0
-    slices = None
-    # Slice or tuple(Slice...)
-    if isinstance(input_slices, slice):
-        slices = (input_slices,)
-    elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], (slice, type(...))):
-        is_have_ellipsis = False
-        for _, element in enumerate(input_slices):
-            if isinstance(element, type(...)):
-                is_have_ellipsis = True
-                break
-        if is_have_ellipsis:
-            slices = ellipsis2slice(input_slices, shape)
-        else:
-            slices = input_slices
-    else:
-        raise IndexError("Tensor's index type is not supported yet.")
-    for s in slices:
-        start = 0 if (s.start is None) else s.start
-        stop = shape[index] if (s.stop is None) else s.stop
-        step = 1 if (s.step is None) else s.step
-        begin.append(start)
-        end.append(stop)
-        strides.append(step)
-        index += 1
-    while index < len(shape):
-        begin.append(0)
-        end.append(shape[index])
-        strides.append(1)
-        index += 1
-    return begin, end, strides
-
-
 def ellipsis2slice(input_, shape):
     """Converts ellipsis to slice."""
     input_slice = input_
@@ -358,30 +326,24 @@ def ellipsis2slice(input_, shape):


 @constexpr
-def slice2indices(input_slices, shape):
+def slice2indices(input_slice, shape):
     """
     Converts slice to indices.

     Inputs:
-        slices (Union[Slice, tuple[Slice]]): Slice tuple or slice.
+        input_slice (Union[Slice, tuple[Slice]]): Slice tuple or slice.
         shape (tuple): The shape of a tensor is an integer element tuple.

     Outputs:
         Tensor, the shape is (n, 1).
     """
-    begin, end, strides = slice_expand(input_slices, shape)
-    np_r = []
-    for i, element in enumerate(shape):
-        s = normalize_start(begin[i], element)
-        e = normalize_stop(end[i], element)
-        if s >= e:
-            return False
-        np_r.append(np.r_[s:e:strides[i]])
-    # Reference: np.ravel_multi_index((np.ix_(np.r_[1:3:1], np.r_[0:4:1], np.r_[4:0:-1])), a.shape)
-    np_ix = np.ix_(*np_r)
-    ravel = np.ravel_multi_index(np_ix, shape)
-    ravel = Tensor(ravel.reshape(-1, 1), dtype=mstype.int32)
-    return ravel
+    start, stop, step = normalize_slice(input_slice, shape[0])
+    if check_slice_empty(start, stop, step):
+        return False
+    grids = ([np.array(list(range(start, stop, step)), dtype=np.int64)] +
+             [np.array(list(range(dim_size)), dtype=np.int64) for dim_size in shape[1:]])
+    mesh = np.ix_(*grids)
+    return Tensor(np.stack(np.broadcast_arrays(*mesh), axis=-1))


 @constexpr
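What the rewritten `slice2indices` builds can be checked in plain NumPy (hypothetical shapes, for illustration):

import numpy as np

def slice2indices(start, stop, step, shape):
    # Cartesian grid: sliced positions along axis 0, full range elsewhere.
    grids = ([np.arange(start, stop, step, dtype=np.int64)] +
             [np.arange(dim_size, dtype=np.int64) for dim_size in shape[1:]])
    mesh = np.ix_(*grids)
    return np.stack(np.broadcast_arrays(*mesh), axis=-1)

idx = slice2indices(0, 2, 1, (3, 4))
print(idx.shape)   # (2, 4, 2): one (row, col) coordinate per selected element
print(idx[0, 0])   # [0 0]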
@@ -406,29 +368,6 @@ def check_indices_value_size(indices_size, value_size):
     return value_size


-@constexpr
-def integer_to_indices(index, shape):
-    """Converts int or tuple[int] to indices."""
-    size = reduce(lambda x, y: x * y, shape)
-    range_ = np.arange(size).reshape(shape)
-    value = range_[index]
-    value = value.reshape(-1, 1)
-    return Tensor(value, dtype=mstype.int32)
-
-
-@constexpr
-def tuple_element_is_int(indexes):
-    """Judges tuple element type."""
-    if not indexes:
-        raise IndexError("Tensor's index cannot be empty.")
-    if isinstance(indexes, tuple):
-        for _, ele in enumerate(indexes):
-            if not isinstance(ele, int):
-                return False
-        return True
-    return False
-
-
 @constexpr
 def tuple_index_int_cnt(types, op_name):
     """count the int type of types which contains the tuple elements' type."""
@@ -439,12 +378,9 @@ def tuple_index_int_cnt(types, op_name):
 @constexpr
 def tuple_index_type_cnt(types, op_name):
     """count the tensor type of types which contains the tuple elements' type."""
-    tensor_cnt = sum(isinstance(ele, mstype.tensor_type) for ele in types)
-    basic_cnt = sum(isinstance(
-        ele, (mstype.Int, mstype.Ellipsis_, mstype.Slice)) for ele in types)
-    if tensor_cnt == len(types):
+    if all(isinstance(ele, mstype.tensor_type) for ele in types):
         return ALL_TENSOR
-    if basic_cnt == len(types):
+    if all(isinstance(ele, (mstype.Int, mstype.Ellipsis_, mstype.Slice)) for ele in types):
         return ALL_BASIC
     return MIXED
@@ -501,19 +437,10 @@ def generate_broadcast_shape(shapes, op_name):

 @constexpr
 def check_two_shapes_need_broadcast(shape_x, shape_y):
-    """Check two shapes need broadcast."""
-    error = ValueError(f"For 'tensor setitem with tensor', the value tensor shape "
-                       f"{shape_y} could not broadcast the required updates shape {shape_x}.")
-    if len(shape_y) > len(shape_x):
-        raise error
-    for i in range(-len(shape_y), 0):
-        if shape_y[i] > shape_x[i]:
-            raise error
-        if shape_y[i] < shape_x[i] and shape_y[i] != 1:
-            raise error
-    if shape_y == shape_x:
-        return False
-    return True
+    """Check shape_y needs to be broadcast to shape_x."""
+    if any(j not in (i, 1) for i, j in zip(reversed(shape_x), reversed(shape_y))):
+        raise ValueError(f"{shape_y} could not broadcast with {shape_x}.")
+    return shape_y != shape_x


 @constexpr
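The condensed check keeps the same rule: right-aligned dimensions of the value must equal the target dimension or be 1. For example:

def check_two_shapes_need_broadcast(shape_x, shape_y):
    if any(j not in (i, 1) for i, j in zip(reversed(shape_x), reversed(shape_y))):
        raise ValueError(f"{shape_y} could not broadcast with {shape_x}.")
    return shape_y != shape_x

assert check_two_shapes_need_broadcast((2, 3), (1, 3)) is True   # needs tiling
assert check_two_shapes_need_broadcast((2, 3), (2, 3)) is False  # already matches
# check_two_shapes_need_broadcast((2, 3), (4, 3)) raises ValueError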
@@ -523,53 +450,12 @@ def compute_multiples(origin_shape, broadcast_shape):
     return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], origin_shape))


-@constexpr
-def compute_new_shape(origin_shape, indexes_shapes_info):
-    """Compute new shape between origin shape with final shape."""
-    new_shape = []
-    for i in indexes_shapes_info:
-        if i == origin_shape:
-            new_shape.extend(origin_shape)
-        else:
-            new_shape.append(1)
-    return tuple(new_shape)
-
-
 @constexpr
 def convert_int_to_slice(tuple_index):
     tuple_index_new = tuple(slice(i, i + 1, 1) for i in tuple_index)
     return tuple_index_new


-@constexpr
-def check_and_transform_int_index(index, shape, op_name):
-    if index < -shape or index >= shape:
-        raise IndexError(f"In the \"{op_name}\", the index should in the range [-{shape}, {shape-1}] to fit "
-                         f"the corresponding dim length, but get {index}.")
-    if index < 0:
-        index += shape
-    return index
-
-
-@constexpr
-def transform_sequence_index(sequence_index, shape, op_name):
-    """transform list or tuple with integer and boolean to tuple with integer index"""
-    bool_count = len(list(filter(lambda index: isinstance(index, bool), sequence_index)))
-    int_count = len(list(filter(lambda index: isinstance(index, int), sequence_index))) - bool_count
-    if int_count == 0 and bool_count != 0:
-        if bool_count == shape:
-            list_index = list(filter(lambda i: sequence_index[i], range(bool_count)))
-        else:
-            raise IndexError("The boolean array should have the same length with the corresponding dimension")
-    else:
-        list_index = [int(index) for index in sequence_index]
-
-    for i, index in enumerate(list_index):
-        list_index[i] = check_and_transform_int_index(index, shape, op_name)
-    sub_tuple_index = tuple(list_index)
-    return sub_tuple_index
-
-
 @constexpr
 def convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape, slice_shapes, fancy_position):
     """Convert a slice to a tensor."""
@@ -585,16 +471,6 @@ def convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape, slice_shapes, fancy_position):
     return slice_index_tensor


-@constexpr
-def check_shapes_same(value_shapes, op_name):
-    """Check if the shapes in the tuple are consistent."""
-    for i, shape in enumerate(value_shapes):
-        if shape != value_shapes[0]:
-            raise ValueError(f"For '{op_name}', the {i}th tensor shape in "
-                             f"value tuple is not same as the first tensor shape.")
-    return True
-
-
 @constexpr
 def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type):
     """Convert a scalar to a tensor."""
@@ -605,18 +481,6 @@ def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type):
     return Tensor(np.full(updates_shape, value), dtype=data_dtype)


-@constexpr
-def convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type):
-    """Convert a tuple of scalar to a tensor."""
-    updates_shape = generate_updates_shape(data_shape, index_shape, op_type)
-    if len(value) != updates_shape[-1]:
-        raise ValueError(f"For '{TENSOR_SETITEM}', the number of elements : {len(value)} "
-                         f"in the updates tuple does not meet the requirements: {updates_shape[-1]}.")
-    array = np.array(value, dtype=mstype.dtype_to_nptype(data_dtype))
-    reps = compute_multiples(updates_shape[-1:], updates_shape)
-    return Tensor(np.tile(array, reps))
-
-
 @constexpr
 def generate_updates_shape(data_shape, index_shape, op_type):
     """Generate updates shape for 'tensor setitem'."""
@@ -679,9 +543,7 @@ def scalar_in_sequence(x, y):
     if y is None:
         raise ValueError("Judge scalar in tuple or list require scalar and sequence should be constant, "
                          "but the sequence is not.")
-    if x in y:
-        return True
-    return False
+    return x in y


 @constexpr
@@ -782,12 +644,6 @@ def get_stride_info_from_tuple(data_shape, tuple_index):
     return tuple(begin_strides), tuple(end_strides), tuple(step_strides), shrink_axis


-@constexpr
-def mstype_eq(x, y):
-    """Determine whether the input `x` equals `y`."""
-    return x == y
-
-
 @constexpr
 def scalar_to_tensor(x):
     """Convert a scalar to a tensor"""
@@ -801,11 +657,6 @@ def unpack(x):
     return x


-@constexpr
-def slice_to_tuple(s):
-    return (s.start, s.stop, s.step)
-
-
 @constexpr
 def normalize_start(start, dim_size):
     """
@@ -833,8 +684,20 @@ def normalize_stop(stop, dim_size):


 @constexpr
-def is_ellipsis(x):
-    return x is Ellipsis
+def normalize_slice(input_slice, dim_size):
+    """Normalizes start, stop, step in a slice."""
+    start = normalize_start(input_slice.start, dim_size)
+    stop = normalize_stop(input_slice.stop, dim_size)
+    step = input_slice.step
+    if step is None:
+        step = 1
+    if step >= 0:
+        start = normalize_start(input_slice.start, dim_size)
+        stop = normalize_stop(input_slice.stop, dim_size)
+    else:
+        start = normalize_stop(input_slice.start, dim_size)
+        stop = normalize_start(input_slice.stop, dim_size)
+    return start, stop, step


 @constexpr
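For intuition, Python's built-in slice.indices performs the same kind of normalization (a reference point only; the constexpr version above clamps defaults itself):

print(slice(1, -1).indices(5))           # (1, 4, 1)
print(slice(None, None, -1).indices(5))  # (4, -1, -1): walks 4, 3, 2, 1, 0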
@@ -869,3 +732,64 @@ def sequence_mul_int(seq, number):
 def check_in_sequence(x, y):
     """Determine whether the input `x` is in the sequence `y`."""
     return x in y
+
+
+@constexpr
+def is_slice(x):
+    return isinstance(x, slice)
+
+
+@constexpr
+def filter_expanded_dims(shape, not_expanded_dim):
+    diff = len(not_expanded_dim) - len(shape)
+    if diff < 0:
+        raise ValueError(f'unable to broadcast {shape}')
+    return tuple(compress(shape, not_expanded_dim[diff:]))
+
+
+@constexpr
+def sequence_to_index(sequence, dim_size):
+    """Transforms sequence to tensor index."""
+    if not sequence:
+        return False
+    if all(isinstance(i, bool) for i in sequence):
+        seq_size = len(sequence)
+        if seq_size != dim_size:
+            raise IndexError(f'dimension is {dim_size} but corresponding boolean dimension is {seq_size}')
+        sequence = tuple(compress(range(dim_size), sequence))
+        if not sequence:
+            return False
+    return make_tensor(sequence, mstype.int64, None, dim_size)
+
+
+@constexpr
+def int_to_index(i, shape):
+    """Converts integer to tensor indices."""
+    dim_size = shape[0]
+    if i < -dim_size or i >= dim_size:
+        raise IndexError(f'index {i} is out of bounds for axis 0 with size {dim_size}')
+    i = i % dim_size
+    if len(shape) == 1:
+        return Tensor([[i]])
+    grids = [np.array(list(range(size)), dtype=np.int64) for size in shape[1:]]
+    mesh = np.ix_(*grids)
+    index = np.stack(np.broadcast_arrays(*mesh), -1)
+    return Tensor(np.insert(index, 0, i, -1))
+
+
+@constexpr
+def rem_not_expanded_dims(idx_advanced, expand_true, tensor_index_ndim, rem_ndim, not_expanded_dim):
+    """Adds remaining dimensions not indexed to not_expanded_dim"""
+    if idx_advanced != -1:
+        if expand_true:
+            # tensor indices generate only one dimension with size 1
+            tensor_dims = (False,)
+        else:
+            tensor_dims = (True,) * tensor_index_ndim
+        not_expanded_dim = not_expanded_dim[:idx_advanced] + tensor_dims + not_expanded_dim[idx_advanced:]
+    return not_expanded_dim + (True,) * rem_ndim
+
+
+@constexpr
+def check_slice_empty(start, stop, step):
+    return (start - stop) * step >= 0
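Quick sanity checks of the new helpers' semantics in plain Python/NumPy (assumes the behavior defined above):

import numpy as np
from itertools import compress

# sequence_to_index: a full-length bool sequence becomes the positions
# where it is True; (True, False, True) on dim_size=3 selects (0, 2).
print(tuple(compress(range(3), (True, False, True))))  # (0, 2)

# int_to_index(1, (3, 2)): every (row, col) coordinate of row 1.
grids = [np.arange(2, dtype=np.int64)]
mesh = np.ix_(*grids)
index = np.stack(np.broadcast_arrays(*mesh), -1)
print(np.insert(index, 0, 1, -1))  # [[1 0] [1 1]]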
@@ -68,7 +68,7 @@ def _equal_mstype(x, y):
     Returns:
         bool, if x == y return true, x != y return false.
     """
-    return const_utils.mstype_eq(x, y)
+    return const_utils.is_same_type(x, y)


 @equal.register("String", "String")
@@ -62,3 +62,17 @@ def _logical_not_tuple(x):
         bool, Return logical not operation result of x.
     """
     return F.bool_not(x.__bool__())
+
+
+@logical_not.register("List")
+def _logical_not_list(x):
+    """
+    Return logical not operation result of a list object.
+
+    Args:
+        x (List): The input list.
+
+    Returns:
+        bool, Return logical not operation result of x.
+    """
+    return F.bool_not(x.__bool__())
@@ -54,7 +54,7 @@ def _not_equal_mstype(x, y):
     Returns:
         bool, if x != y return true, x == y return false.
     """
-    return not const_utils.mstype_eq(x, y)
+    return not const_utils.is_same_type(x, y)


 @not_equal.register("String", "String")
@@ -214,9 +214,6 @@ def _tensor_setitem_by_tuple_with_number(data, tuple_index, value):
     Outputs:
         Tensor, element type and shape is same as data.
     """
-    if compile_utils.tuple_indices_have_false(tuple_index):
-        return data
-    tuple_index = compile_utils.format_tuple_indices(tuple_index)
     return compile_utils.tensor_setitem_by_tuple_with_number(data, tuple_index, value)
@@ -238,9 +235,6 @@ def _tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
     Outputs:
         Tensor, element type and shape is same as data.
     """
-    if compile_utils.tuple_indices_have_false(tuple_index):
-        return data
-    tuple_index = compile_utils.format_tuple_indices(tuple_index)
     return compile_utils.tensor_setitem_by_tuple_with_tensor(data, tuple_index, value)
@@ -263,9 +257,6 @@ def _tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
     Outputs:
         Tensor, element type and shape is same as data.
     """
-    if compile_utils.tuple_indices_have_false(tuple_index):
-        return data
-    tuple_index = compile_utils.format_tuple_indices(tuple_index)
     return compile_utils.tensor_setitem_by_tuple_with_sequence(data, tuple_index, value)
@@ -288,9 +279,6 @@ def _tensor_setitem_by_tuple_with_list(data, tuple_index, value):
     Outputs:
         Tensor, element type and shape is same as data.
     """
-    if compile_utils.tuple_indices_have_false(tuple_index):
-        return data
-    tuple_index = compile_utils.format_tuple_indices(tuple_index)
     return compile_utils.tensor_setitem_by_tuple_with_sequence(data, tuple_index, value)
@@ -587,6 +575,86 @@ def _tensor_setitem_by_ellipsis_with_tuple(data, index, value):
     return compile_utils.tensor_setitem_by_ellipsis_with_sequence(data, value)


+@setitem.register("Tensor", "None", "Number")
+def _tensor_setitem_by_none_with_number(data, index, value):
+    """
+    Tensor assignment.
+
+    Note:
+        Syntax support: A[None] = u
+        Restraint condition: A is a Tensor.
+                             u is a Number.
+    Inputs:
+        data (Tensor): Assigned tensor.
+        index (None): Index is ``None``.
+        value (Number): Assignment value.
+
+    Outputs:
+        Tensor, element type and shape is same as data.
+    """
+    return compile_utils.tensor_setitem_by_ellipsis_with_number(data, value)
+
+
+@setitem.register("Tensor", "None", "Tensor")
+def _tensor_setitem_by_none_with_tensor(data, index, value):
+    """
+    Tensor assignment.
+
+    Note:
+        Syntax support: A[None] = u
+        Restraint condition: A is a Tensor.
+                             u is a Tensor.
+    Inputs:
+        data (Tensor): Assigned tensor.
+        index (None): Index is ``None``.
+        value (Tensor): Assignment value.
+
+    Outputs:
+        Tensor, element type and shape is same as data.
+    """
+    return compile_utils.tensor_setitem_by_ellipsis_with_tensor(data, value)
+
+
+@setitem.register("Tensor", "None", "List")
+def _tensor_setitem_by_none_with_list(data, index, value):
+    """
+    Tensor assignment.
+
+    Note:
+        Syntax support: A[None] = u
+        Restraint condition: A is a Tensor.
+                             u is a List, with all elements of equal length.
+    Inputs:
+        data (Tensor): Assigned tensor.
+        index (None): Index is ``None``.
+        value (List): Assignment value.
+
+    Outputs:
+        Tensor, element type and shape is same as data.
+    """
+    return compile_utils.tensor_setitem_by_ellipsis_with_sequence(data, value)
+
+
+@setitem.register("Tensor", "None", "Tuple")
+def _tensor_setitem_by_none_with_tuple(data, index, value):
+    """
+    Tensor assignment.
+
+    Note:
+        Syntax support: A[None] = u
+        Restraint condition: A is a Tensor.
+                             u is a Tuple, with all elements of equal length.
+    Inputs:
+        data (Tensor): Assigned tensor.
+        index (None): Index is ``None``.
+        value (Tuple): Assignment value.
+
+    Outputs:
+        Tensor, element type and shape is same as data.
+    """
+    return compile_utils.tensor_setitem_by_ellipsis_with_sequence(data, value)
+
+
 @setitem.register("Tensor", "List", "Number")
 def _tensor_setitem_by_list_with_number(data, index, value):
     """
@@ -608,9 +676,6 @@ def _tensor_setitem_by_list_with_number(data, index, value):
     index = compile_utils.format_list_indices(index, data.shape[0])
     if isinstance(index, Tensor):
         return compile_utils.tensor_setitem_by_tensor_with_number(data, index, value)
-    if compile_utils.tuple_indices_have_false(index):
-        return data
-    index = compile_utils.format_tuple_indices(index)
     return compile_utils.tensor_setitem_by_tuple_with_number(data, index, value)
@@ -635,9 +700,6 @@ def _tensor_setitem_by_list_with_tensor(data, index, value):
     index = compile_utils.format_list_indices(index, data.shape[0])
     if isinstance(index, Tensor):
         return compile_utils.tensor_setitem_by_tensor_with_tensor(data, index, value)
-    if compile_utils.tuple_indices_have_false(index):
-        return data
-    index = compile_utils.format_tuple_indices(index)
     return compile_utils.tensor_setitem_by_tuple_with_tensor(data, index, value)
@@ -662,9 +724,6 @@ def _tensor_setitem_by_list_with_tuple(data, index, value):
     index = compile_utils.format_list_indices(index, data.shape[0])
     if isinstance(index, Tensor):
         return compile_utils.tensor_setitem_by_tensor_with_sequence(data, index, value)
-    if compile_utils.tuple_indices_have_false(index):
-        return data
-    index = compile_utils.format_tuple_indices(index)
     return compile_utils.tensor_setitem_by_tuple_with_sequence(data, index, value)
@@ -689,7 +748,4 @@ def _tensor_setitem_by_list_with_list(data, index, value):
     index = compile_utils.format_list_indices(index, data.shape[0])
     if isinstance(index, Tensor):
         return compile_utils.tensor_setitem_by_tensor_with_sequence(data, index, value)
-    if compile_utils.tuple_indices_have_false(index):
-        return data
-    index = compile_utils.format_tuple_indices(index)
     return compile_utils.tensor_setitem_by_tuple_with_sequence(data, index, value)
@@ -772,8 +772,11 @@ def test_tensor_assign_slice_value_2():
 def test_tensor_assign_exception():
     net = TensorAssignWithSlice()
     net2 = TensorAssignWithSlice2()
-    net_e1 = TensorAssignWithSliceError1()
-    net_e2 = TensorAssignWithSliceError2()
+    # The test case is no longer appropriate since x[1:3:-1] = np.array(2) does
+    # not incur an error in numpy, which leaves the original array unchanged after
+    # the assign operation.
+    # net_e1 = TensorAssignWithSliceError1()
+    # net_e2 = TensorAssignWithSliceError2()
     a = np.arange(60).reshape(3, 4, 5)
     ck = np.arange(60).reshape(3, 4, 5)
     b = Tensor([1], dtype=mstype.float32)
@@ -787,8 +790,8 @@ def test_tensor_assign_exception():
     tck = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
     # Error for A[Slice] = Number
     # 1. A[Slice] = Number, Slice error
-    with pytest.raises(ValueError):
-        net_e2(t, 2)
+    # with pytest.raises(ValueError):
+    #     net_e2(t, 2)

     # Error for A[Slice] = U, U is a Tensor
     # 1. A[Slice] = U, u.size is error
@@ -809,13 +812,13 @@ def test_tensor_assign_exception():
     with pytest.raises(ValueError):
         net(Ta, Tb, Tck)
     # 3. A[Tuple(Slice...)] = U, Slice error
-    with pytest.raises(IndexError):
-        net_e1(Ta, b)
+    # with pytest.raises(IndexError):
+    #     net_e1(Ta, b)

     # Error for A[Tuple(Slice...)] = Number
     # 1. A[Tuple(Slice...)] = Number, Slice error
-    with pytest.raises(IndexError):
-        net_e1(Ta, 2)
+    # with pytest.raises(IndexError):
+    #     net_e1(Ta, 2)

     net = TensorAssignWithInteger()
     # Error for A[Number] = scalar/Tensor
@@ -195,6 +195,7 @@ def test_setitem_by_slice():
         x[5:0:3] = 5
         x[5:5:5] = 6
         x[-1:2] = 7
+        x[1:0:-1] = 8
         return x
     setup_testcase(x, cases)
@@ -214,3 +215,19 @@ def test_setitem_by_tuple_of_slices():
         x[1:1, 2:2] = 6
         return x
     setup_testcase(x, cases)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_setitem_dim_expand():
+    x = onp.ones((2, 3, 4), dtype=onp.float32)
+    def cases(x):
+        x[None, True, [1, 0], (False, True, True), [2]] = 2
+        x[([[0]]), ..., [[1]]] = [[[3, 3, 3]]]
+        x[0:1] = [[2, 3, 4, 5]]
+        x[..., (0, 1, 2), None, :, True, None] = [[[3], [3], [3], [3]]]
+        return x
+    setup_testcase(x, cases)
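Since setup_testcase compares against NumPy, each case can be checked directly; for instance (plain NumPy, illustrative):

import numpy as onp

x = onp.ones((2, 3, 4), dtype=onp.float32)
# A (1, 4) value broadcasts across the (1, 3, 4) sliced region:
x[0:1] = [[2, 3, 4, 5]]
print(x[0, 0])  # [2. 3. 4. 5.]
print(x[0, 2])  # [2. 3. 4. 5.]
print(x[1, 0])  # [1. 1. 1. 1.]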
@@ -448,8 +448,11 @@ def test_tensor_assign():
     context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
     net = TensorAssignWithSlice()
     net2 = TensorAssignWithSlice2()
-    net_e1 = TensorAssignWithSliceError1()
-    net_e2 = TensorAssignWithSliceError2()
+    # The test case is no longer appropriate since x[1:3:-1] = np.array(2) does
+    # not incur an error in numpy, which leaves the original array unchanged after
+    # the assign operation.
+    # net_e1 = TensorAssignWithSliceError1()
+    # net_e2 = TensorAssignWithSliceError2()
     a = np.arange(60).reshape(3, 4, 5)
     ck = np.arange(60).reshape(3, 4, 5)
     b = Tensor([1], dtype=mstype.float32)
@@ -465,8 +468,9 @@ def test_tensor_assign():
     net2(t, b, tck)
     # Error for A[Slice] = Number
     # 1. A[Slice] = Number, 0 in shape
-    with pytest.raises(ValueError):
-        net_e2(t, Tensor(2, mstype.int32))
+    # with pytest.raises(ValueError):
+    #     net_e2(t, Tensor(2, mstype.int32))

     # Error for A[Slice] = U, U is a Tensor
     # 1. A[Slice] = U, u.size is error
@@ -487,13 +491,13 @@ def test_tensor_assign():
     with pytest.raises(ValueError):
         net(Ta, Tb, Tck)
     # 3. A[Tuple(Slice...)] = U, Slice error
-    with pytest.raises(IndexError):
-        net_e1(Ta, b)
+    # with pytest.raises(IndexError):
+    #     net_e1(Ta, b)

     # Error for A[Tuple(Slice...)] = Number
     # 1. A[Tuple(Slice...)] = Number, Slice error
-    with pytest.raises(IndexError):
-        net_e1(Ta, Tensor(2, mstype.int32))
+    # with pytest.raises(IndexError):
+    #     net_e1(Ta, Tensor(2, mstype.int32))

     net = TensorAssignWithInteger()
     # Error for A[Number] = scalar/Tensor