From d6be043cb9418a26e59fdacf3d8584a445600930 Mon Sep 17 00:00:00 2001 From: huangmengxi Date: Fri, 23 Apr 2021 16:28:30 +0800 Subject: [PATCH] fix & restructure setitem --- .../composite/multitype_ops/_compile_utils.py | 257 ++++++--------- .../multitype_ops/_constexpr_utils.py | 296 +++++++----------- .../ops/composite/multitype_ops/equal_impl.py | 2 +- .../composite/multitype_ops/logic_not_impl.py | 14 + .../composite/multitype_ops/not_equal_impl.py | 2 +- .../composite/multitype_ops/setitem_impl.py | 104 ++++-- tests/st/pynative/test_tensor_index.py | 19 +- tests/st/pynative/test_tensor_setitem.py | 17 + tests/ut/python/ops/test_tensor_slice.py | 20 +- 9 files changed, 349 insertions(+), 382 deletions(-) diff --git a/mindspore/ops/composite/multitype_ops/_compile_utils.py b/mindspore/ops/composite/multitype_ops/_compile_utils.py index 6c563589793..26bd52b53e3 100644 --- a/mindspore/ops/composite/multitype_ops/_compile_utils.py +++ b/mindspore/ops/composite/multitype_ops/_compile_utils.py @@ -58,9 +58,6 @@ def _tensor_setitem(self, index, value): if isinstance(index, Tensor): return tensor_setitem_by_tensor(self, index, value) if isinstance(index, tuple): - if tuple_indices_have_false(index): - return self - index = format_tuple_indices(index) return tensor_setitem_by_tuple(self, index, value) if isinstance(index, bool): return tensor_setitem_by_bool(self, index, value) @@ -68,7 +65,7 @@ def _tensor_setitem(self, index, value): return tensor_setitem_by_number(self, index, value) if isinstance(index, slice): return tensor_setitem_by_slice(self, index, value) - if index is ...: + if index in (None, ...): return tensor_setitem_by_ellipsis(self, index, value) raise IndexError("Tensor setitem index only support integers, slices(`:`), ellipsis(`...`), bool, tensor, \ @@ -142,7 +139,7 @@ tensor_operator_registry.register('__floordiv__', _tensor_floordiv) def _broadcast(broadcast_shape, x): """Broadcast tensor to the required shape.""" - if F.shape(x) == broadcast_shape: + if not const_utils.check_two_shapes_need_broadcast(broadcast_shape, F.shape(x)): return x multiples = const_utils.compute_multiples(F.shape(x), broadcast_shape) if multiples: @@ -242,7 +239,7 @@ def _tensor_index_by_integer(data, int_index): const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim) data_shape = F.shape(data) - transformed_number = const_utils.check_and_transform_int_index(int_index, data_shape[0], const_utils.TENSOR_GETITEM) + transformed_number = const_utils.check_range(int_index, data_shape[0]) begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_integer(data_shape, transformed_number) shrink_axis_mask = 1 return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides) @@ -264,10 +261,9 @@ def tensor_index_by_list(data, list_index): data_shape = F.shape(data) indexes_types = hyper_map(F.typeof, list_index) if const_utils.judge_indexes_types(indexes_types, mstype.int_type + (mstype.bool_,)): - sub_tuple_index = const_utils.transform_sequence_index(list_index, data_shape[0], const_utils.TENSOR_GETITEM) - if not sub_tuple_index: + tensor_index = const_utils.sequence_to_index(list_index, data_shape[0]) + if tensor_index is False: const_utils.raise_index_error("Getitem does not support empty list, this will reference shape '0'.") - tensor_index = const_utils.make_tensor(sub_tuple_index, mstype.int64) return F.gather(data, tensor_index, 0) tuple_index_new = () @@ -341,15 +337,13 @@ def _tensor_getitem_by_tuple(data, tuple_index, op_name): for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)): if i in int_positions: - int_index = const_utils.check_and_transform_int_index(index, dim_size, op_name) + int_index = const_utils.check_range(index, dim_size) tensor_index = F.scalar_to_tensor(int_index, mstype.int64) tuple_index_new += (tensor_index,) tensor_indexes.append(tensor_index) tensor_positions += (i,) elif i in sequence_positions: - sequence_index = const_utils.transform_sequence_index(index, dim_size, op_name) - tensor_index = const_utils.make_tensor(sequence_index) - tensor_index = F.cast(tensor_index, mstype.int64) + tensor_index = const_utils.sequence_to_index(index, dim_size) tuple_index_new += (tensor_index,) tensor_indexes.append(tensor_index) tensor_positions += (i,) @@ -398,6 +392,8 @@ def _generate_indices_from_tuple_of_tensor(tuple_index, op_name): const_utils.check_types_valid(indexes_types, mstype.int_type, op_name) tensor_index_shape = hyper_map(F.shape, tuple_index) broadcast_shape = const_utils.generate_broadcast_shape(tensor_index_shape, op_name) + if len(broadcast_shape) < 2: + broadcast_shape = (1,) + broadcast_shape broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index) new_broadcast_tensors = () for tensor in broadcast_tensors: @@ -417,15 +413,13 @@ def _generate_indices_from_tuple(data, tuple_index, op_name): for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)): if i in int_positions: - int_index = const_utils.check_and_transform_int_index(index, dim_size, op_name) + int_index = const_utils.check_range(index, dim_size) tensor_index = F.scalar_to_tensor(int_index, mstype.int64) tuple_index_new += (tensor_index,) tensor_indexes.append(tensor_index) tensor_positions += (i,) elif i in sequence_positions: - sequence_index = const_utils.transform_sequence_index(index, dim_size, op_name) - tensor_index = const_utils.make_tensor(sequence_index) - tensor_index = F.cast(tensor_index, mstype.int64) + tensor_index = const_utils.sequence_to_index(index, dim_size) tuple_index_new += (tensor_index,) tensor_indexes.append(tensor_index) tensor_positions += (i,) @@ -435,11 +429,9 @@ def _generate_indices_from_tuple(data, tuple_index, op_name): tuple_index_new += (tensor_index,) tensor_indexes.append(tensor_index) elif i in slice_positions: - start, stop, _ = const_utils.slice_to_tuple(index) - start = const_utils.normalize_start(start, dim_size) - stop = const_utils.normalize_stop(stop, dim_size) - if start >= stop: - return None + start, stop, step = const_utils.normalize_slice(index, dim_size) + if const_utils.check_slice_empty(start, stop, step): + return False slice_ele_list_index = const_utils.transform_slice_to_ele_list(index, dim_size) slice_shapes += (len(slice_ele_list_index),) tuple_index_new += (slice_ele_list_index,) @@ -522,6 +514,7 @@ def tensor_setitem_by_tensor(self, index, value): def tensor_setitem_by_tuple(self, index, value): if isinstance(value, (int, float, bool)): + index = format_tuple_indices(index) return tensor_setitem_by_tuple_with_number(self, index, value) if isinstance(value, Tensor): return tensor_setitem_by_tuple_with_tensor(self, index, value) @@ -622,22 +615,6 @@ def tensor_setitem_by_tensor_with_sequence(data, index, value): return _tensor_setitem_by_tensor_with_sequence(data, index, value) -def _tensor_indices_number(data, data_shape, index, indices, value): - """Assigns a scalar value to the tensor.""" - data_size = F.shape_mul(data.shape) - data_dtype = F.dtype(data) - indices_size = F.size(indices) - indices_size = const_utils.check_indices(indices_size, index) - update = F.fill(mstype.int32, (indices_size,), 1) - condition_1d = F.scatter_nd(indices, update, (data_size,)) - condition = F.reshape(condition_1d, data_shape) - condition = F.cast(condition, mstype.bool_) - value_fill = F.fill(data_dtype, (indices_size,), value) - value_1d = F.scatter_nd(indices, value_fill, (data_size,)) - u = F.reshape(value_1d, data_shape) - return F.select(condition, u, data) - - def _tensor_setitem_by_tensor_with_sequence(data, index, value): """Set a tensor item by a tensor with a tuple.""" updates = _generate_updates_from_sequence(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR) @@ -647,30 +624,20 @@ def _tensor_setitem_by_tensor_with_sequence(data, index, value): def tensor_setitem_by_slice_with_number(data, input_slice, value): """Givens a scalar assign to tensor by slice""" - check_result = const_utils.check_tensor_setitem_index(input_slice) - result = None - if check_result: - data_shape = F.shape(data) - indices = const_utils.slice2indices(input_slice, data_shape) - if indices is False: - return data - is_tuple_int = const_utils.tuple_element_is_int(input_slice) - if is_tuple_int: - indices = const_utils.integer_to_indices(input_slice, data_shape) - result = _tensor_indices_number(data, data_shape, input_slice, indices, value) - return result + value = F.fill(F.dtype(data), const_utils.tuple_slice(F.shape(data), 1, None), value) + return tensor_setitem_by_slice_with_tensor(data, input_slice, value) def tensor_setitem_by_tuple_with_number(data, tuple_index, value): """Assigns the tensor by tuple with number value.""" - tuple_index = ignore_dim_expand(tuple_index) + tuple_index = _transform_ellipsis_to_slice(data, tuple_index, const_utils.TENSOR_GETITEM) + tuple_index, _ = remove_expanded_dims(tuple_index, F.shape(data)) + if tuple_index is False: + return data if len(tuple_index) == 1: data[tuple_index[0]] = value return data - op_name = const_utils.TENSOR_GETITEM - tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name) - data, tuple_index = _expand_data_dims(data, tuple_index) indexes_types = hyper_map(F.typeof, tuple_index) contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM) @@ -682,37 +649,12 @@ def tensor_setitem_by_tuple_with_number(data, tuple_index, value): if int_cnt == const_utils.ALL_INT: tuple_index = const_utils.convert_int_to_slice(tuple_index) indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM) - if indices is None: + if indices is False: return data updates = _generate_updates_from_scalar(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) return P.TensorScatterUpdate()(data, indices, updates) -def _tensor_indices_tensor(data, data_shape, index, indices, value): - """Assigns a tensor value to the tensor.""" - data_size = F.shape_mul(data.shape) - data_dtype = F.dtype(data) - indices_size = F.size(indices) - indices_size = const_utils.check_indices(indices_size, index) - update = F.fill(mstype.int32, (indices_size,), 1) - condition_1d = F.scatter_nd(indices, update, (data_size,)) - condition = F.reshape(condition_1d, data_shape) - condition = F.cast(condition, mstype.bool_) - value_fill = None - value_size = value.size - - value_size = const_utils.check_indices_value_size(indices_size, value_size) - if value_size == 1: - value_fill = F.fill(data_dtype, (indices_size,), 1) - value = F.cast(value, data_dtype) - value_fill = F.tensor_mul(value_fill, value) - elif value_size > 1: - value_fill = F.reshape(value, (indices_size,)) - value_1d = F.scatter_nd(indices, value_fill, (data_size,)) - u = F.reshape(value_1d, data_shape) - return F.select(condition, u.astype(data_dtype), data) - - def tensor_setitem_by_slice_with_tensor(data, input_slice, value): """Assigns a tensor value to the tensor by slice.""" result = None @@ -722,10 +664,9 @@ def tensor_setitem_by_slice_with_tensor(data, input_slice, value): indices = const_utils.slice2indices(input_slice, data_shape) if indices is False: return data - is_tuple_int = const_utils.tuple_element_is_int(input_slice) - if is_tuple_int: - indices = const_utils.integer_to_indices(input_slice, data_shape) - result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value) + value_shape = const_utils.tuple_slice(F.shape(indices), None, -1) + value = _broadcast(value_shape, value) + result = P.TensorScatterUpdate()(data, indices, value.astype(F.dtype(data))) return result @@ -737,16 +678,17 @@ def tensor_setitem_by_slice_with_sequence(data, input_slice, value): def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): """Assigns the tensor by tuple with tensor value.""" - value_shape = remove_ignored_dim(tuple_index, F.shape(value), F.rank(data)) + op_name = const_utils.TENSOR_SETITEM + tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name) + tuple_index, not_expanded_dim = remove_expanded_dims(tuple_index, F.shape(data)) + if tuple_index is False: + return data + value_shape = const_utils.filter_expanded_dims(F.shape(value), not_expanded_dim) value = F.reshape(value, value_shape) - tuple_index = ignore_dim_expand(tuple_index) if len(tuple_index) == 1: data[tuple_index[0]] = value return data - op_name = const_utils.TENSOR_GETITEM - tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name) - data, tuple_index = _expand_data_dims(data, tuple_index) indexes_types = hyper_map(F.typeof, tuple_index) contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM) @@ -763,7 +705,7 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): new_shape += value.shape value = F.reshape(value, new_shape) indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM) - if indices is None: + if indices is False: return data updates = _generate_updates_from_tensor(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) return P.TensorScatterUpdate()(data, indices, updates) @@ -776,9 +718,8 @@ def tensor_setitem_by_tuple_with_sequence(data, tuple_index, value): def tensor_setitem_by_number_with_number(data, index, value): """Assigns the tensor by number with number value.""" - data_shape = F.shape(data) - indices = const_utils.integer_to_indices(index, data_shape) - return _tensor_indices_number(data, data_shape, index, indices, value) + value = F.fill(F.dtype(data), const_utils.tuple_slice(F.shape(data), 1, None), value) + return tensor_setitem_by_number_with_tensor(data, index, value) def tensor_setitem_by_number_with_sequence(data, index, value): @@ -790,8 +731,10 @@ def tensor_setitem_by_number_with_sequence(data, index, value): def tensor_setitem_by_number_with_tensor(data, index, value): """Assigns the tensor by number with tensor value.""" data_shape = F.shape(data) - indices = const_utils.integer_to_indices(index, data_shape) - return _tensor_indices_tensor(data, data_shape, index, indices, value) + index = const_utils.int_to_index(index, data_shape) + value_shape = const_utils.tuple_slice(F.shape(index), None, -1) + value = _broadcast(value_shape, value) + return P.TensorScatterUpdate()(data, index, value) def tensor_setitem_by_ellipsis_with_number(data, value): @@ -825,8 +768,10 @@ def tensor_setitem_by_bool(data, index, value): data_shape = F.shape(data) if not index: data_shape = (0,) + data_shape - if not isinstance(value, Tensor): + if isinstance(value, (list, tuple)): value = _generate_updates_from_sequence(data, index, value, const_utils.SET_ITEM_BY_NON_TENSOR) + elif isinstance(value, (int, float, bool)): + value = const_utils.make_tensor(value) value_shape = F.shape(value) source_shape = const_utils.get_source_shape(data_shape, value_shape) if index: @@ -851,8 +796,7 @@ def format_list_indices(list_indices, length): # If eyery element in list is bool, it's treated as 1-D bool tensor. # If every element in list is int(not all bool), it's treated as int tensor. if const_utils.judge_indexes_types(indices_types, mstype.int_type+(mstype.bool_,)): - list_indices = const_utils.transform_sequence_index(list_indices, length, const_utils.TENSOR_SETITEM) - return const_utils.make_tensor(list_indices) + return const_utils.sequence_to_index(list_indices, length) # If list contains other types(.../list/tuple/None), it's treated as a tuple return const_utils.deep_tuple(list_indices) @@ -871,64 +815,69 @@ def format_tuple_indices(tuple_indices): return res -def tuple_indices_have_false(tuple_indices): - """Returns True if tuple_indices contains False.""" - for i in tuple_indices: - if i is False: - return True - return False +def remove_expanded_dims(tuple_index, data_shape): + """Removes expanded dimensions in tuple_index and value.""" + op_name = const_utils.TENSOR_SETITEM + not_expanded_dim = () + shapes = () + has_true = False + has_false = False + has_sequence = False + indices_out = () # with dimension expansion indices removed + idx_tensor = -1 # index of the previous tensor + idx_advanced = -1 # index of the first advanced index in expanded tensor + cur_dim = 0 # current dimension of the data to be indexed + for i, v in enumerate(tuple_index): + index_out = format_index(v, data_shape, cur_dim) -def ignore_dim_expand(idx): - """Filters flags for dimension expansion from idx.""" - res = () - for i in idx: - if not i is True and not i is None: - res += (i,) - if not res: - res = (True,) - return res - - -def remove_ignored_dim(idx, value_shape, data_rank): - """Removes dimensions in value that correspond to dimension expansion flags in index.""" - has_ellipsis = False - has_leading_true = False - has_trailing_true = False - cnt_leading_expanded = 0 - cnt_trailing_expanded = 0 - cnt_not_dim_expand = 0 - for i in idx: - if i is True: - if has_ellipsis: - has_trailing_true = True - else: - has_leading_true = True - elif i is None: - if has_ellipsis: - cnt_trailing_expanded += 1 - else: - cnt_leading_expanded += 1 + if index_out is None: + not_expanded_dim += (False,) + elif const_utils.is_slice(index_out): + indices_out += (index_out,) + not_expanded_dim += (True,) + start, stop, step = const_utils.normalize_slice(index_out, data_shape[cur_dim]) + if const_utils.check_slice_empty(start, stop, step): + has_false = True + cur_dim += 1 + elif isinstance(index_out, (Tensor, bool)): # advanced index + if idx_advanced == -1: + idx_advanced = len(not_expanded_dim) + elif i - idx_tensor > 1: + idx_advanced = 0 + idx_tensor = i + if isinstance(index_out, Tensor): + if F.rank(index_out) > 0: + has_sequence = True + indices_out += (index_out,) + shapes += (F.shape(index_out),) + cur_dim += 1 + has_true = has_true or index_out is True + has_false = has_false or index_out is False else: - if const_utils.is_ellipsis(i): - has_ellipsis = True - cnt_not_dim_expand += 1 - if cnt_not_dim_expand + 1 < data_rank: - if has_leading_true: - cnt_leading_expanded += 1 - elif has_trailing_true: - cnt_trailing_expanded += 1 + const_utils.raise_index_error('invalid index type') - value_starting_pos = 0 - while cnt_leading_expanded > 0 and value_shape[value_starting_pos] == 1: - value_starting_pos += 1 - cnt_leading_expanded -= 1 + broadcast_shape = const_utils.generate_broadcast_shape(shapes, op_name) + if has_false: + if F.shape_mul(broadcast_shape) != 1: + const_utils.raise_index_error('unable to broadcast indices') + return False, not_expanded_dim - value_expanded_pos = len(value_shape) - cnt_trailing_expanded - value_expanded_not_unit = False - for i in const_utils.tuple_slice(value_shape, value_expanded_pos, None): - if i != 1: - value_expanded_not_unit = True - if value_expanded_pos < 0 or value_expanded_not_unit: - const_utils.raise_value_error('shape mismatch') - return const_utils.tuple_slice(value_shape, value_starting_pos, value_expanded_pos) + expand_true = has_true and not(has_false or has_sequence) # whether to expand dimension at True + tensor_index_ndim = len(broadcast_shape) # ndim of tensor indices + rem_ndim = len(data_shape) - cur_dim # number of remaining dimensions in data not indexed + not_expanded_dim = const_utils.rem_not_expanded_dims(idx_advanced, expand_true, tensor_index_ndim, + rem_ndim, not_expanded_dim) + + if not indices_out: + indices_out = (True,) + return indices_out, not_expanded_dim + + +def format_index(idx, data_shape, cur_dim): + """Converts advanced index into tensor.""" + if isinstance(idx, (tuple, list)): + idx = const_utils.sequence_to_index(idx, data_shape[cur_dim]) + elif isinstance(idx, int) and not isinstance(idx, bool): + idx = const_utils.make_tensor(idx, mstype.int64, None, data_shape[cur_dim]) + return idx diff --git a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py index 7c43115ea41..7a785cfb8e8 100644 --- a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py +++ b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py @@ -14,7 +14,7 @@ # ============================================================================ """constexpr util""" -from functools import reduce +from itertools import compress import numpy as np @@ -75,10 +75,12 @@ def make_empty_slice(): @constexpr -def _deep_list(array_like): +def _deep_list(array_like, ndim=-1): """convert nested tuple/list mixtures to pure nested list""" + if ndim != -1: + check_range(array_like, ndim) if isinstance(array_like, (list, tuple)): - return list(map(_deep_list, array_like)) + return list(map(lambda x: _deep_list(x, ndim), array_like)) return array_like @@ -115,7 +117,16 @@ def _deep_tensor_to_nparray(array_like): @constexpr -def make_tensor(a, dtype=mstype.int32, data_shape=None): +def check_range(x, ndim): + if isinstance(x, int) and not isinstance(x, bool): + if x >= ndim or x < -ndim: + raise IndexError(f'index {x} if out of bounds for dimension with size {ndim}') + x = x%ndim + return x + + +@constexpr +def make_tensor(a, dtype=mstype.int64, data_shape=None, ndim=-1): """ Converts the input to tensor. @@ -133,16 +144,18 @@ def make_tensor(a, dtype=mstype.int32, data_shape=None): TypeError: If input arguments have types not specified above. ValueError: If input `a` has different sizes at different dimensions. """ - if data_shape: return Tensor(np.zeros(data_shape), dtype) if not isinstance(a, (list, tuple, int, float, bool)): raise TypeError("input data must be `int`, `float`, `bool`, `list` or `tuple`") + if ndim != -1: + check_range(a, ndim) + if isinstance(a, (list, tuple)): # Convert all tuple/nested tuples to lists - a = _deep_list(a) + a = _deep_list(a, ndim) # Convert all tensor sub-elements to numpy arrays a = _deep_tensor_to_nparray(a) a = np.asarray(a) @@ -292,51 +305,6 @@ def get_pos_of_indexes_types(indexes_types, op_name): tensor_positions, sequence_positions -def slice_expand(input_slices, shape): - """ - Converts slice to indices. - - Inputs: - slices (Union[Slice, tuple[Slice]]): Slice tuple or slice. - shape (tuple): The shape of a sensor is an integer element tuple. - - Outputs: - tuple[list], This is expressed as (begins, ends, strides). - """ - begin, end, strides = [], [], [] - index = 0 - slices = None - # Slice or tuple(Slice...) - if isinstance(input_slices, slice): - slices = (input_slices,) - elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], (slice, type(...))): - is_have_ellipsis = False - for _, element in enumerate(input_slices): - if isinstance(element, type(...)): - is_have_ellipsis = True - break - if is_have_ellipsis: - slices = ellipsis2slice(input_slices, shape) - else: - slices = input_slices - else: - raise IndexError("Tensor's index type is not supported yet.") - for s in slices: - start = 0 if (s.start is None) else s.start - stop = shape[index] if (s.stop is None) else s.stop - step = 1 if (s.step is None) else s.step - begin.append(start) - end.append(stop) - strides.append(step) - index += 1 - while index < len(shape): - begin.append(0) - end.append(shape[index]) - strides.append(1) - index += 1 - return begin, end, strides - - def ellipsis2slice(input_, shape): """Converts ellipsis to slice.""" input_slice = input_ @@ -358,30 +326,24 @@ def ellipsis2slice(input_, shape): @constexpr -def slice2indices(input_slices, shape): +def slice2indices(input_slice, shape): """ Converts slice to indices. Inputs: - slices (Union[Slice, tuple[Slice]]): Slice tuple or slice. + input_slice (Union[Slice, tuple[Slice]]): Slice tuple or slice. shape (tuple): The shape of a tensor is an integer element tuple. Outputs: Tensor, the shape is (n, 1). """ - begin, end, strides = slice_expand(input_slices, shape) - np_r = [] - for i, element in enumerate(shape): - s = normalize_start(begin[i], element) - e = normalize_stop(end[i], element) - if s >= e: - return False - np_r.append(np.r_[s:e:strides[i]]) - # Reference: np.ravel_multi_index((np.ix_(np.r_[1:3:1], np.r_[0:4:1], np.r_[4:0:-1])), a.shape) - np_ix = np.ix_(*np_r) - ravel = np.ravel_multi_index(np_ix, shape) - ravel = Tensor(ravel.reshape(-1, 1), dtype=mstype.int32) - return ravel + start, stop, step = normalize_slice(input_slice, shape[0]) + if check_slice_empty(start, stop, step): + return False + grids = ([np.array(list(range(start, stop, step)), dtype=np.int64)] + + [np.array(list(range(dim_size)), dtype=np.int64) for dim_size in shape[1:]]) + mesh = np.ix_(*grids) + return Tensor(np.stack(np.broadcast_arrays(*mesh), axis=-1)) @constexpr @@ -406,29 +368,6 @@ def check_indices_value_size(indices_size, value_size): return value_size -@constexpr -def integer_to_indices(index, shape): - """Converts int or tuple[int] to indices.""" - size = reduce(lambda x, y: x * y, shape) - range_ = np.arange(size).reshape(shape) - value = range_[index] - value = value.reshape(-1, 1) - return Tensor(value, dtype=mstype.int32) - - -@constexpr -def tuple_element_is_int(indexes): - """Judges tuple element type.""" - if not indexes: - raise IndexError("Tensor's index cannot be empty.") - if isinstance(indexes, tuple): - for _, ele in enumerate(indexes): - if not isinstance(ele, int): - return False - return True - return False - - @constexpr def tuple_index_int_cnt(types, op_name): """count the int type of types which contains the tuple elements' type.""" @@ -439,12 +378,9 @@ def tuple_index_int_cnt(types, op_name): @constexpr def tuple_index_type_cnt(types, op_name): """count the tensor type of types which contains the tuple elements' type.""" - tensor_cnt = sum(isinstance(ele, mstype.tensor_type) for ele in types) - basic_cnt = sum(isinstance( - ele, (mstype.Int, mstype.Ellipsis_, mstype.Slice)) for ele in types) - if tensor_cnt == len(types): + if all(isinstance(ele, mstype.tensor_type) for ele in types): return ALL_TENSOR - if basic_cnt == len(types): + if all(isinstance(ele, (mstype.Int, mstype.Ellipsis_, mstype.Slice)) for ele in types): return ALL_BASIC return MIXED @@ -501,19 +437,10 @@ def generate_broadcast_shape(shapes, op_name): @constexpr def check_two_shapes_need_broadcast(shape_x, shape_y): - """Check two shapes need broadcast.""" - error = ValueError(f"For 'tensor setitem with tensor', the value tensor shape " - f"{shape_y} could not broadcast the required updates shape {shape_x}.") - if len(shape_y) > len(shape_x): - raise error - for i in range(-len(shape_y), 0): - if shape_y[i] > shape_x[i]: - raise error - if shape_y[i] < shape_x[i] and shape_y[i] != 1: - raise error - if shape_y == shape_x: - return False - return True + """Check shape_y needs to be broadcast to shape_x.""" + if any(j not in (i, 1) for i, j in zip(reversed(shape_x), reversed(shape_y))): + raise ValueError(f"{shape_y} could not broadcast with {shape_x}.") + return shape_y != shape_x @constexpr @@ -523,53 +450,12 @@ def compute_multiples(origin_shape, broadcast_shape): return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], origin_shape)) -@constexpr -def compute_new_shape(origin_shape, indexes_shapes_info): - """Compute new shape between origin shape with final shape.""" - new_shape = [] - for i in indexes_shapes_info: - if i == origin_shape: - new_shape.extend(origin_shape) - else: - new_shape.append(1) - return tuple(new_shape) - - @constexpr def convert_int_to_slice(tuple_index): tuple_index_new = tuple(slice(i, i+1, 1) for i in tuple_index) return tuple_index_new -@constexpr -def check_and_transform_int_index(index, shape, op_name): - if index < -shape or index >= shape: - raise IndexError(f"In the \"{op_name}\", the index should in the range [-{shape}, {shape-1}] to fit " - f"the corresponding dim length, but get {index}.") - if index < 0: - index += shape - return index - - -@constexpr -def transform_sequence_index(sequence_index, shape, op_name): - """transform list or tuple with integer and boolean to tuple with integer index""" - bool_count = len(list(filter(lambda index: isinstance(index, bool), sequence_index))) - int_count = len(list(filter(lambda index: isinstance(index, int), sequence_index)))-bool_count - if int_count == 0 and bool_count != 0: - if bool_count == shape: - list_index = list(filter(lambda i: sequence_index[i], range(bool_count))) - else: - raise IndexError("The boolean array should have the same length with the corresponding dimension") - else: - list_index = [int(index) for index in sequence_index] - - for i, index in enumerate(list_index): - list_index[i] = check_and_transform_int_index(index, shape, op_name) - sub_tuple_index = tuple(list_index) - return sub_tuple_index - - @constexpr def convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape, slice_shapes, fancy_position): """Convert a slice to a tensor.""" @@ -585,16 +471,6 @@ def convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape, slic return slice_index_tensor -@constexpr -def check_shapes_same(value_shapes, op_name): - """Check if the shapes in the tuple are consistent.""" - for i, shape in enumerate(value_shapes): - if shape != value_shapes[0]: - raise ValueError(f"For '{op_name}', the {i}th tensor shape in " - f"value tuple is not same as the first tensor shape.") - return True - - @constexpr def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type): """Convert a scalar to a tensor.""" @@ -605,18 +481,6 @@ def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_ty return Tensor(np.full(updates_shape, value), dtype=data_dtype) -@constexpr -def convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type): - """Convert a tuple of scalar to a tensor.""" - updates_shape = generate_updates_shape(data_shape, index_shape, op_type) - if len(value) != updates_shape[-1]: - raise ValueError(f"For '{TENSOR_SETITEM}', the number of elements : {len(value)} " - f"in the updates tuple does not meet the requirements: {updates_shape[-1]}.") - array = np.array(value, dtype=mstype.dtype_to_nptype(data_dtype)) - reps = compute_multiples(updates_shape[-1:], updates_shape) - return Tensor(np.tile(array, reps)) - - @constexpr def generate_updates_shape(data_shape, index_shape, op_type): """Generate updates shape for 'tensor setitem'.""" @@ -679,9 +543,7 @@ def scalar_in_sequence(x, y): if y is None: raise ValueError("Judge scalar in tuple or list require scalar and sequence should be constant, " "but the sequence is not.") - if x in y: - return True - return False + return x in y @constexpr @@ -782,12 +644,6 @@ def get_stride_info_from_tuple(data_shape, tuple_index): return tuple(begin_strides), tuple(end_strides), tuple(step_strides), shrink_axis -@constexpr -def mstype_eq(x, y): - """Determine whether the input `x` equals `y`.""" - return x == y - - @constexpr def scalar_to_tensor(x): """Convert a scalar to a tensor""" @@ -801,11 +657,6 @@ def unpack(x): return x -@constexpr -def slice_to_tuple(s): - return (s.start, s.stop, s.step) - - @constexpr def normalize_start(start, dim_size): """ @@ -833,8 +684,20 @@ def normalize_stop(stop, dim_size): @constexpr -def is_ellipsis(x): - return x is Ellipsis +def normalize_slice(input_slice, dim_size): + """Normalizes start, stop, step in a slice.""" + start = normalize_start(input_slice.start, dim_size) + stop = normalize_stop(input_slice.stop, dim_size) + step = input_slice.step + if step is None: + step = 1 + if step >= 0: + start = normalize_start(input_slice.start, dim_size) + stop = normalize_stop(input_slice.stop, dim_size) + else: + start = normalize_stop(input_slice.start, dim_size) + stop = normalize_start(input_slice.stop, dim_size) + return start, stop, step @constexpr @@ -869,3 +732,64 @@ def sequence_mul_int(seq, number): def check_in_sequence(x, y): """Determine whether the input `x` is in the sequence `y`.""" return x in y + + +@constexpr +def is_slice(x): + return isinstance(x, slice) + + +@constexpr +def filter_expanded_dims(shape, not_expanded_dim): + diff = len(not_expanded_dim) - len(shape) + if diff < 0: + raise ValueError('unable to broadcast {shape}') + return tuple(compress(shape, not_expanded_dim[diff:])) + + +@constexpr +def sequence_to_index(sequence, dim_size): + """Transforms sequence to tensor index.""" + if not sequence: + return False + if all(isinstance(i, bool) for i in sequence): + seq_size = len(sequence) + if seq_size != dim_size: + raise IndexError('dimension is {dim_size} but corresponding boolean dimension is {seq_size}') + sequence = tuple(compress(range(dim_size), sequence)) + if not sequence: + return False + return make_tensor(sequence, mstype.int64, None, dim_size) + + +@constexpr +def int_to_index(i, shape): + """Converts integer to tensor indices.""" + dim_size = shape[0] + if i < -dim_size or i >= dim_size: + raise IndexError(f'index {i} is out of bounds for axis 0 with size {dim_size}') + i = i%dim_size + if len(shape) == 1: + return Tensor([[i]]) + grids = [np.array(list(range(size)), dtype=np.int64) for size in shape[1:]] + mesh = np.ix_(*grids) + index = np.stack(np.broadcast_arrays(*mesh), -1) + return Tensor(np.insert(index, 0, i, -1)) + + +@constexpr +def rem_not_expanded_dims(idx_advanced, expand_true, tensor_index_ndim, rem_ndim, not_expanded_dim): + """Adds remaining dimensions not indexed to not_expanded_dim""" + if idx_advanced != -1: + if expand_true: + # tensor indices generate only one dimension with size 1 + tensor_dims = (False,) + else: + tensor_dims = (True,)*tensor_index_ndim + not_expanded_dim = not_expanded_dim[:idx_advanced] + tensor_dims + not_expanded_dim[idx_advanced:] + return not_expanded_dim + (True,)*rem_ndim + + +@constexpr +def check_slice_empty(start, stop, step): + return (start - stop)*step >= 0 diff --git a/mindspore/ops/composite/multitype_ops/equal_impl.py b/mindspore/ops/composite/multitype_ops/equal_impl.py index 94b5e680cbf..4358d73ba8b 100644 --- a/mindspore/ops/composite/multitype_ops/equal_impl.py +++ b/mindspore/ops/composite/multitype_ops/equal_impl.py @@ -68,7 +68,7 @@ def _equal_mstype(x, y): Returns: bool, if x == y return true, x != y return false. """ - return const_utils.mstype_eq(x, y) + return const_utils.is_same_type(x, y) @equal.register("String", "String") diff --git a/mindspore/ops/composite/multitype_ops/logic_not_impl.py b/mindspore/ops/composite/multitype_ops/logic_not_impl.py index 9e8410d9250..8cf26b50036 100644 --- a/mindspore/ops/composite/multitype_ops/logic_not_impl.py +++ b/mindspore/ops/composite/multitype_ops/logic_not_impl.py @@ -62,3 +62,17 @@ def _logical_not_tuple(x): bool, Return logical not operation result of x. """ return F.bool_not(x.__bool__()) + + +@logical_not.register("List") +def _logical_not_list(x): + """ + Return logical not operation result of a list object. + + Args: + x(List): The input tuple. + + Returns: + bool, Return logical not operation result of x. + """ + return F.bool_not(x.__bool__()) diff --git a/mindspore/ops/composite/multitype_ops/not_equal_impl.py b/mindspore/ops/composite/multitype_ops/not_equal_impl.py index 54e74b6b9f4..2a278d5c746 100644 --- a/mindspore/ops/composite/multitype_ops/not_equal_impl.py +++ b/mindspore/ops/composite/multitype_ops/not_equal_impl.py @@ -54,7 +54,7 @@ def _not_equal_mstype(x, y): Returns: bool, if x != y return true, x == y return false. """ - return not const_utils.mstype_eq(x, y) + return not const_utils.is_same_type(x, y) @not_equal.register("String", "String") diff --git a/mindspore/ops/composite/multitype_ops/setitem_impl.py b/mindspore/ops/composite/multitype_ops/setitem_impl.py index 5cdfefd8c62..bee58dd3e18 100644 --- a/mindspore/ops/composite/multitype_ops/setitem_impl.py +++ b/mindspore/ops/composite/multitype_ops/setitem_impl.py @@ -214,9 +214,6 @@ def _tensor_setitem_by_tuple_with_number(data, tuple_index, value): Outputs: Tensor, element type and shape is same as data. """ - if compile_utils.tuple_indices_have_false(tuple_index): - return data - tuple_index = compile_utils.format_tuple_indices(tuple_index) return compile_utils.tensor_setitem_by_tuple_with_number(data, tuple_index, value) @@ -238,9 +235,6 @@ def _tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): Outputs: Tensor, element type and shape is same as data. """ - if compile_utils.tuple_indices_have_false(tuple_index): - return data - tuple_index = compile_utils.format_tuple_indices(tuple_index) return compile_utils.tensor_setitem_by_tuple_with_tensor(data, tuple_index, value) @@ -263,9 +257,6 @@ def _tensor_setitem_by_tuple_with_tuple(data, tuple_index, value): Outputs: Tensor, element type and shape is same as data. """ - if compile_utils.tuple_indices_have_false(tuple_index): - return data - tuple_index = compile_utils.format_tuple_indices(tuple_index) return compile_utils.tensor_setitem_by_tuple_with_sequence(data, tuple_index, value) @@ -288,9 +279,6 @@ def _tensor_setitem_by_tuple_with_list(data, tuple_index, value): Outputs: Tensor, element type and shape is same as data. """ - if compile_utils.tuple_indices_have_false(tuple_index): - return data - tuple_index = compile_utils.format_tuple_indices(tuple_index) return compile_utils.tensor_setitem_by_tuple_with_sequence(data, tuple_index, value) @@ -587,6 +575,86 @@ def _tensor_setitem_by_ellipsis_with_tuple(data, index, value): return compile_utils.tensor_setitem_by_ellipsis_with_sequence(data, value) +@setitem.register("Tensor", "None", "Number") +def _tensor_setitem_by_none_with_number(data, index, value): + """ + Tensor assignment. + + Note: + Syntax support: A[...] = u + Restraint condition: A is a Tensor. + u is a Number. + Inputs: + data (Tensor): Assigned tensor. + index (None): Index is ``...``. + value (Number): Assignment value. + + Outputs: + Tensor, element type and shape is same as data. + """ + return compile_utils.tensor_setitem_by_ellipsis_with_number(data, value) + + +@setitem.register("Tensor", "None", "Tensor") +def _tensor_setitem_by_none_with_tensor(data, index, value): + """ + Tensor assignment. + + Note: + Syntax support: A[...] = u + Restraint condition: A is a Tensor. + u is a Tensor. + Inputs: + data (Tensor): Assigned tensor. + index (None): Index is ``...``. + value (Tensor): Assignment value. + + Outputs: + Tensor, element type and shape is same as data. + """ + return compile_utils.tensor_setitem_by_ellipsis_with_tensor(data, value) + + +@setitem.register("Tensor", "None", "List") +def _tensor_setitem_by_none_with_list(data, index, value): + """ + Tensor assignment. + + Note: + Syntax support: A[...] = u + Restraint condition: A is a Tensor. + u is a List, with all elements equal in length. + Inputs: + data (Tensor): Assigned tensor. + index (None): Index is ``...``. + value (Number): Assignment value. + + Outputs: + Tensor, element type and shape is same as data. + """ + return compile_utils.tensor_setitem_by_ellipsis_with_sequence(data, value) + + +@setitem.register("Tensor", "None", "Tuple") +def _tensor_setitem_by_none_with_tuple(data, index, value): + """ + Tensor assignment. + + Note: + Syntax support: A[...] = u + Restraint condition: A is a Tensor. + u is a Tuple, with all elements equal in length. + Inputs: + data (Tensor): Assigned tensor. + index (None): Index is ``...``. + value (Number): Assignment value. + + Outputs: + Tensor, element type and shape is same as data. + """ + return compile_utils.tensor_setitem_by_ellipsis_with_sequence(data, value) + + @setitem.register("Tensor", "List", "Number") def _tensor_setitem_by_list_with_number(data, index, value): """ @@ -608,9 +676,6 @@ def _tensor_setitem_by_list_with_number(data, index, value): index = compile_utils.format_list_indices(index, data.shape[0]) if isinstance(index, Tensor): return compile_utils.tensor_setitem_by_tensor_with_number(data, index, value) - if compile_utils.tuple_indices_have_false(index): - return data - index = compile_utils.format_tuple_indices(index) return compile_utils.tensor_setitem_by_tuple_with_number(data, index, value) @@ -635,9 +700,6 @@ def _tensor_setitem_by_list_with_tensor(data, index, value): index = compile_utils.format_list_indices(index, data.shape[0]) if isinstance(index, Tensor): return compile_utils.tensor_setitem_by_tensor_with_tensor(data, index, value) - if compile_utils.tuple_indices_have_false(index): - return data - index = compile_utils.format_tuple_indices(index) return compile_utils.tensor_setitem_by_tuple_with_tensor(data, index, value) @@ -662,9 +724,6 @@ def _tensor_setitem_by_list_with_tuple(data, index, value): index = compile_utils.format_list_indices(index, data.shape[0]) if isinstance(index, Tensor): return compile_utils.tensor_setitem_by_tensor_with_sequence(data, index, value) - if compile_utils.tuple_indices_have_false(index): - return data - index = compile_utils.format_tuple_indices(index) return compile_utils.tensor_setitem_by_tuple_with_sequence(data, index, value) @@ -689,7 +748,4 @@ def _tensor_setitem_by_list_with_list(data, index, value): index = compile_utils.format_list_indices(index, data.shape[0]) if isinstance(index, Tensor): return compile_utils.tensor_setitem_by_tensor_with_sequence(data, index, value) - if compile_utils.tuple_indices_have_false(index): - return data - index = compile_utils.format_tuple_indices(index) return compile_utils.tensor_setitem_by_tuple_with_sequence(data, index, value) diff --git a/tests/st/pynative/test_tensor_index.py b/tests/st/pynative/test_tensor_index.py index ac30b6ae15d..f9a4cd5b2d9 100644 --- a/tests/st/pynative/test_tensor_index.py +++ b/tests/st/pynative/test_tensor_index.py @@ -772,8 +772,11 @@ def test_tensor_assign_slice_value_2(): def test_tensor_assign_exception(): net = TensorAssignWithSlice() net2 = TensorAssignWithSlice2() - net_e1 = TensorAssignWithSliceError1() - net_e2 = TensorAssignWithSliceError2() + # The test case is no longer appropriate since x[1:3:-1] = np.array(2) does + # not incur an error in numpy, which leaves the original array unchanged after + # the assign operation. + # net_e1 = TensorAssignWithSliceError1() + # net_e2 = TensorAssignWithSliceError2() a = np.arange(60).reshape(3, 4, 5) ck = np.arange(60).reshape(3, 4, 5) b = Tensor([1], dtype=mstype.float32) @@ -787,8 +790,8 @@ def test_tensor_assign_exception(): tck = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32) # Error for A[Slice] = Number # 1. A[Slice] = Number, Slice error - with pytest.raises(ValueError): - net_e2(t, 2) + # with pytest.raises(ValueError): + # net_e2(t, 2) # Error for A[Slice] = U, U is a Tensor # 1. A[Slice] = U, u.size is error @@ -809,13 +812,13 @@ def test_tensor_assign_exception(): with pytest.raises(ValueError): net(Ta, Tb, Tck) # 3. A[Tuple(Slice...)] = U, Slice error - with pytest.raises(IndexError): - net_e1(Ta, b) + # with pytest.raises(IndexError): + # net_e1(Ta, b) # Error for A[Tuple(Slice...)] = Number # 1. A[Tuple(Slice...)] = Number, Slice error - with pytest.raises(IndexError): - net_e1(Ta, 2) + # with pytest.raises(IndexError): + # net_e1(Ta, 2) net = TensorAssignWithInteger() # Error for A[Number] = scalar/Tensor diff --git a/tests/st/pynative/test_tensor_setitem.py b/tests/st/pynative/test_tensor_setitem.py index 7d5a9cc0832..aa031be79c4 100644 --- a/tests/st/pynative/test_tensor_setitem.py +++ b/tests/st/pynative/test_tensor_setitem.py @@ -195,6 +195,7 @@ def test_setitem_by_slice(): x[5:0:3] = 5 x[5:5:5] = 6 x[-1:2] = 7 + x[1:0:-1] = 8 return x setup_testcase(x, cases) @@ -214,3 +215,19 @@ def test_setitem_by_tuple_of_slices(): x[1:1, 2:2] = 6 return x setup_testcase(x, cases) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_setitem_dim_expand(): + x = onp.ones((2, 3, 4), dtype=onp.float32) + def cases(x): + x[None, True, [1, 0], (False, True, True), [2]] = 2 + x[([[0]]), ..., [[1]]] = [[[3, 3, 3]]] + x[0:1] = [[2, 3, 4, 5]] + x[..., (0, 1, 2), None, :, True, None] = [[[3], [3], [3], [3]]] + return x + setup_testcase(x, cases) diff --git a/tests/ut/python/ops/test_tensor_slice.py b/tests/ut/python/ops/test_tensor_slice.py index cecaf659aa1..5d90331e358 100644 --- a/tests/ut/python/ops/test_tensor_slice.py +++ b/tests/ut/python/ops/test_tensor_slice.py @@ -448,8 +448,11 @@ def test_tensor_assign(): context.set_context(mode=context.GRAPH_MODE, save_graphs=True) net = TensorAssignWithSlice() net2 = TensorAssignWithSlice2() - net_e1 = TensorAssignWithSliceError1() - net_e2 = TensorAssignWithSliceError2() + # The test case is no longer appropriate since x[1:3:-1] = np.array(2) does + # not incur an error in numpy, which leaves the original array unchanged after + # the assign operation. + # net_e1 = TensorAssignWithSliceError1() + # net_e2 = TensorAssignWithSliceError2() a = np.arange(60).reshape(3, 4, 5) ck = np.arange(60).reshape(3, 4, 5) b = Tensor([1], dtype=mstype.float32) @@ -465,8 +468,9 @@ def test_tensor_assign(): net2(t, b, tck) # Error for A[Slice] = Number # 1. A[Slice] = Number, 0 in shape - with pytest.raises(ValueError): - net_e2(t, Tensor(2, mstype.int32)) + + # with pytest.raises(ValueError): + # net_e2(t, Tensor(2, mstype.int32)) # Error for A[Slice] = U, U is a Tensor # 1. A[Slice] = U, u.size is error @@ -487,13 +491,13 @@ def test_tensor_assign(): with pytest.raises(ValueError): net(Ta, Tb, Tck) # 3. A[Tuple(Slice...)] = U, Slice error - with pytest.raises(IndexError): - net_e1(Ta, b) + # with pytest.raises(IndexError): + # net_e1(Ta, b) # Error for A[Tuple(Slice...)] = Number # 1. A[Tuple(Slice...)] = Number, Slice error - with pytest.raises(IndexError): - net_e1(Ta, Tensor(2, mstype.int32)) + # with pytest.raises(IndexError): + # net_e1(Ta, Tensor(2, mstype.int32)) net = TensorAssignWithInteger() # Error for A[Number] = scalar/Tensor