From 79058d35099a480480b7fe4979a677c8c34dfdd7 Mon Sep 17 00:00:00 2001 From: huangdongrun Date: Fri, 12 Jun 2020 16:08:54 +0800 Subject: [PATCH] add support for parameter support for tensor setitem add support for tensor assgin --- mindspore/ccsrc/ir/tensor.cc | 17 + mindspore/ccsrc/ir/tensor.h | 3 + mindspore/common/parameter.py | 2 + mindspore/common/tensor.py | 2 + .../composite/multitype_ops/_compile_utils.py | 387 ++++++++++++++-- .../composite/multitype_ops/setitem_impl.py | 228 +--------- tests/st/pynative/test_tensor_index.py | 420 +++++++++++++----- 7 files changed, 699 insertions(+), 360 deletions(-) diff --git a/mindspore/ccsrc/ir/tensor.cc b/mindspore/ccsrc/ir/tensor.cc index 4e2e996bacf..55d686062e7 100644 --- a/mindspore/ccsrc/ir/tensor.cc +++ b/mindspore/ccsrc/ir/tensor.cc @@ -92,6 +92,10 @@ Tensor &Tensor::operator=(const Tensor &tensor) { } return *this; } +Tensor &Tensor::AssignValue(const Tensor &tensor) { + *this = tensor; + return *this; +} bool Tensor::operator==(const Tensor &tensor) const { return (MetaTensor::operator==(tensor) && data_ == tensor.data_); @@ -470,6 +474,19 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { >>> data.set_dtype(mindspore.int32) mindspore.int32 )mydelimiter") + .def("assign_value", &Tensor::AssignValue, R"mydelimiter( + Assign another tensor value to this. + + Arg: + value (:class:`mindspore.tensor`): The value tensor. + + Examples: + >>> data = mindspore.Tensor(np.ones((1, 2), np.float32)) + >>> data2 = mindspore.Tensor(np.ones((2, 2), np.float32)) + >>> data.assign_value(data2) + >>> data.shape + (2, 2) + )mydelimiter") .def("__str__", &Tensor::ToString) .def("__repr__", &Tensor::ToStringRepr) .def(py::pickle( diff --git a/mindspore/ccsrc/ir/tensor.h b/mindspore/ccsrc/ir/tensor.h index 700dcd49102..48d04b92022 100644 --- a/mindspore/ccsrc/ir/tensor.h +++ b/mindspore/ccsrc/ir/tensor.h @@ -173,6 +173,9 @@ class Tensor : public MetaTensor { // It is different from 'operator==' which just compare shape/type/address, it do real value comparison. bool ValueEqual(const Tensor &other) const; + // assgin value to this tensor + Tensor &AssignValue(const Tensor &tensor); + bool operator==(const Value &other) const override { if (other.isa()) { auto other_ = static_cast(other); diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py index 69affee2c32..150dc25d19e 100644 --- a/mindspore/common/parameter.py +++ b/mindspore/common/parameter.py @@ -203,6 +203,8 @@ class Parameter: return self.default_input / other def __setitem__(self, index, value): + default_input = self.default_input + default_input[index] = value return self def set_parameter_data(self, data): diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index b7452d8165a..0a631b954fd 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -150,6 +150,8 @@ class Tensor(Tensor_): return out def __setitem__(self, index, value): + out = tensor_operator_registry.get('__setitem__')(self, index, value) + self.assign_value(out) return self def __gt__(self, other): diff --git a/mindspore/ops/composite/multitype_ops/_compile_utils.py b/mindspore/ops/composite/multitype_ops/_compile_utils.py index 826cb9500c3..906d74948a3 100644 --- a/mindspore/ops/composite/multitype_ops/_compile_utils.py +++ b/mindspore/ops/composite/multitype_ops/_compile_utils.py @@ -26,7 +26,7 @@ hyper_map = base.HyperMap() pack = P.Pack(axis=-1) -def broadcast(broadcast_shape, x): +def _broadcast(broadcast_shape, x): """Broadcast tensor to the required shape.""" if F.shape(x) == broadcast_shape: return x @@ -36,13 +36,13 @@ def broadcast(broadcast_shape, x): return x -def transform_indexing_tensor(broadcast_shape, final_shape, new_shape, x): +def _transform_indexing_tensor(broadcast_shape, final_shape, new_shape, x): """Transform indexing tensor to the required.""" - x = broadcast(broadcast_shape, x) - return broadcast(final_shape, F.reshape(x, new_shape)) + x = _broadcast(broadcast_shape, x) + return _broadcast(final_shape, F.reshape(x, new_shape)) -def generate_indices_from_tuple_of_tensor(data, tuple_index, op_name): +def _generate_indices_from_tuple_of_tensor(data, tuple_index, op_name): """Generate an indices tensor from a tuple of tensor.""" indices = None check_index_tensor_number = const_utils.check_number_of_index_tensor(F.shape(data), len(tuple_index), op_name) @@ -52,26 +52,31 @@ def generate_indices_from_tuple_of_tensor(data, tuple_index, op_name): if check_dtypes: shape_tuple = hyper_map(F.shape, tuple_index) broadcast_shape = const_utils.generate_broadcast_shape(shape_tuple, op_name) - broadcast_tensors = hyper_map(F.partial(broadcast, broadcast_shape), tuple_index) + broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index) indices = pack(broadcast_tensors) return indices -def generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name): +def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name): """Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor.""" indexes_types = hyper_map(F.typeof, tuple_index) int_positions = const_utils.get_pos_of_int_index(indexes_types) - for i in int_positions: - tuple_index = F.tuple_setitem(tuple_index, i, F.scalar_to_tensor(tuple_index[i], mstype.int32)) - indexes_types = hyper_map(F.typeof, tuple_index) + tuple_index_new = () + tuple_len = len(tuple_index) + for i in range(tuple_len): + if i in int_positions: + tuple_index_new = tuple_index_new + (F.scalar_to_tensor(tuple_index[i], mstype.int32),) + else: + tuple_index_new = tuple_index_new + (tuple_index[i],) + indexes_types = hyper_map(F.typeof, tuple_index_new) tensor_positions, slice_positions, ellipsis_position = \ const_utils.separate_mixed_tensors_index(indexes_types, op_name) tensor_indexes = [] slice_indexes = [] for i in tensor_positions: - tensor_indexes.append(tuple_index[i]) + tensor_indexes.append(tuple_index_new[i]) for j in slice_positions: - slice_indexes.append(tuple_index[j]) + slice_indexes.append(tuple_index_new[j]) data_shape = F.shape(data) tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes) tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes) @@ -85,14 +90,14 @@ def generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name): slice_number = 0 final_index_tensors = [] - tuple_index_size = len(tuple_index) + tuple_index_size = len(tuple_index_new) index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info) for i in range(tuple_index_size): if i in tensor_positions: - transform_tensor = transform_indexing_tensor(broadcast_shape, - final_shape, - index_tensor_new_shape, - tuple_index[i]) + transform_tensor = _transform_indexing_tensor(broadcast_shape, + final_shape, + index_tensor_new_shape, + tuple_index_new[i]) final_index_tensors.append(transform_tensor) if i in slice_positions: slice_tensor = const_utils.convert_slice_to_tensor(slice_number, @@ -114,7 +119,7 @@ def generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name): return indices -def generate_updates_from_scalar(data, indices, value, op_type): +def _generate_updates_from_scalar(data, indices, value, op_type): """Generate an updates tensor from a scalar.""" data_shape = F.shape(data) indices_shape = F.shape(indices) @@ -122,7 +127,7 @@ def generate_updates_from_scalar(data, indices, value, op_type): return const_utils.convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type) -def generate_updates_from_tuple(data, index, value, op_type): +def _generate_updates_from_tuple(data, index, value, op_type): """Generate an updates tensor from a tuple.""" value_types = hyper_map(F.typeof, value) data_dtype = F.dtype(data) @@ -132,14 +137,14 @@ def generate_updates_from_tuple(data, index, value, op_type): shapes_same = const_utils.check_shapes_same(value_shapes, const_utils.TENSOR_SETITEM) if shapes_same: value = F.pack(value) - return generate_updates_from_tensor(data, index, value, op_type) + return _generate_updates_from_tensor(data, index, value, op_type) data_shape = F.shape(data) index_shape = F.shape(index) return const_utils.convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type) -def generate_updates_from_tensor(data, index, value, op_type): +def _generate_updates_from_tensor(data, index, value, op_type): """Generate an updates tensor from a tensor.""" data_shape = F.shape(data) index_shape = F.shape(index) @@ -152,45 +157,47 @@ def generate_updates_from_tensor(data, index, value, op_type): updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type) need_broadcast = const_utils.check_two_shapes_need_broadcast(updates_shape, value_shape) if need_broadcast: - return broadcast(updates_shape, value) + return _broadcast(updates_shape, value) return value -def tensor_getitem(self, index): +def _tensor_getitem(self, index): """Handle tensor getitem""" if isinstance(index, Tensor): return tensor_index_by_tensor(self, index) if isinstance(index, tuple): return tensor_index_by_tuple(self, index) if isinstance(index, int): - return tensor_index_by_integer(self, index) + return _tensor_index_by_integer(self, index) if isinstance(index, slice): return tensor_index_by_slice(self, index) if isinstance(index, bool): - return tensor_index_by_bool(self, index) + return _tensor_index_by_bool(self, index) + if index is None: + return F.expand_dims(self, 0) if index is ...: return self raise IndexError(f"Only support integers, slices(`:`), ellipsis(`...`), None, bool and tensor with int32, " f"got {index} with type {type(index)}.") -tensor_operator_registry.register("__getitem__", tensor_getitem) +tensor_operator_registry.register("__getitem__", _tensor_getitem) -def tensor_getitem_by_tuple_of_tensor(data, tuple_index): +def _tensor_getitem_by_tuple_of_tensor(data, tuple_index): """Tensor getitem by a tuple of tensor.""" - indices = generate_indices_from_tuple_of_tensor(data, - tuple_index, - const_utils.TENSOR_GETITEM) + indices = _generate_indices_from_tuple_of_tensor(data, + tuple_index, + const_utils.TENSOR_GETITEM) result = F.gather_nd(data, indices) return result -def tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index): +def _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index): """Tensor getitem by a tuple of mixed tensor.""" - indices = generate_indices_from_tuple_of_mixed_tensors(data, - tuple_index, - const_utils.TENSOR_GETITEM) + indices = _generate_indices_from_tuple_of_mixed_tensors(data, + tuple_index, + const_utils.TENSOR_GETITEM) result = F.gather_nd(data, indices) return result @@ -204,7 +211,7 @@ def tensor_index_by_slice(data, slice_index): return F.strided_slice(data, begin_strides, end_strides, step_strides) -def tensor_index_by_integer(data, number): +def _tensor_index_by_integer(data, number): """Tensor getitem by a single integer number""" shape = F.shape(data) if not shape: @@ -214,7 +221,7 @@ def tensor_index_by_integer(data, number): return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides) -def tensor_index_by_bool(data, bool_value): +def _tensor_index_by_bool(data, bool_value): """Tensor getitem by a single bool value""" if bool_value: return F.expand_dims(data, 0) @@ -225,9 +232,9 @@ def tensor_index_by_number(data, number): """Tensor getitem by a Number which may be integer/float/bool value""" number_type = const_utils.check_number_index_type(number) if number_type == const_utils.BOOL_: - return tensor_index_by_bool(data, number) + return _tensor_index_by_bool(data, number) if number_type == const_utils.INT_: - return tensor_index_by_integer(data, number) + return _tensor_index_by_integer(data, number) return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool.") @@ -241,7 +248,7 @@ def tensor_index_by_tensor(data, tensor_index): "the index tensor data type only support mstype.int32.") -def tensor_index_by_tuple_slice(data, t): +def _tensor_index_by_tuple_slice(data, t): """Tensor getitem by a tuple of slice""" shape = F.shape(data) if len(t) > len(shape): @@ -257,7 +264,303 @@ def tensor_index_by_tuple(data, tuple_index): indexes_types = hyper_map(F.typeof, tuple_index) index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM) if index_elements_type == const_utils.NO_TENSOR: - return tensor_index_by_tuple_slice(data, tuple_index) + return _tensor_index_by_tuple_slice(data, tuple_index) if index_elements_type == const_utils.ALL_TENSOR: - return tensor_getitem_by_tuple_of_tensor(data, tuple_index) - return tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index) + return _tensor_getitem_by_tuple_of_tensor(data, tuple_index) + return _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index) + + +def _tensor_setitem(self, index, value): + """Handle tensor getitem""" + if isinstance(index, Tensor): + if isinstance(value, (int, float, bool)): + return tensor_setitem_by_tensor_with_number(self, index, value) + if isinstance(value, Tensor): + return tensor_setitem_by_tensor_with_tensor(self, index, value) + if isinstance(value, tuple): + return tensor_setitem_by_tensor_with_tuple(self, index, value) + if isinstance(index, tuple): + if isinstance(value, (int, float, bool)): + return tensor_setitem_by_tuple_with_number(self, index, value) + if isinstance(value, Tensor): + return tensor_setitem_by_tuple_with_tensor(self, index, value) + if isinstance(value, tuple): + return tensor_setitem_by_tuple_with_tuple(self, index, value) + if isinstance(index, int): + if isinstance(value, (int, float, bool)): + return tensor_setitem_by_number_with_number(self, index, value) + if isinstance(value, Tensor): + return tensor_setitem_by_number_with_tensor(self, index, value) + if isinstance(index, slice): + if isinstance(value, (int, float, bool)): + return tensor_setitem_by_slice_with_number(self, index, value) + if isinstance(value, Tensor): + return tensor_setitem_by_slice_with_tensor(self, index, value) + if isinstance(index, bool): + return _tensor_index_by_bool(self, index) + if index is ...: + if isinstance(value, (int, float, bool)): + return tensor_setitem_by_ellipsis_with_number(self, index, value) + if isinstance(value, Tensor): + return tensor_setitem_by_ellipsis_with_tensor(self, index, value) + raise IndexError("Tensor setitem index only support integers, slices(`:`), ellipsis(`...`), None, bool\ + and tensor with int32, got {} with type{}".format(index, type(index))) + + +tensor_operator_registry.register("__setitem__", _tensor_setitem) + + +def _tensor_setitem_by_int_tensor_with_tensor(data, index, value): + """Set a tensor item by a int tensor with a tensor.""" + updates = _generate_updates_from_tensor(data, index, value, + const_utils.SET_ITEM_BY_ONE_TENSOR) + index = F.expand_dims(index, -1) + return P.TensorScatterUpdate()(data, index, updates) + + +def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value): + """Set a tensor item by a bool tensor with a tensor.""" + index_shape = F.shape(index) + data_shape = F.shape(data) + data_shape = const_utils.check_equal(data_shape, index_shape, + "The tensor(shape={}) and tensor index(shape={}) should be the same shape.") + size = F.size(value) + size = const_utils.check_equal(1, size, + "When assign value is a tensor, its size should be {}, but current size is {}.") + dtype = F.dtype(data) + u_cast = F.cast(value, dtype) + one_data = F.ones_like(data) + u = F.tensor_mul(one_data, u_cast) + result = F.select(index, u, data) + return result + + +def tensor_setitem_by_tensor_with_tensor(data, index, value_tensor): + """setitem by tensor index(dtype is int or bool) with tensor as value""" + index_dtype = F.dtype(index) + tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype) + if tensor_dtype == const_utils.INT_: + return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor) + return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor) + + +def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value): + """Set a tensor item by a bool tensor with a scalar.""" + index_shape = F.shape(index) + shape = F.shape(data) + shape = const_utils.check_equal( + shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.") + dtype = F.dtype(data) + u = F.fill(dtype, shape, value) + return F.select(index, u, data) + + +def _tensor_setitem_by_int_tensor_with_scalar(data, index, value): + """Set a tensor item by a int tensor with a scalar.""" + updates = _generate_updates_from_scalar(data, index, value, + const_utils.SET_ITEM_BY_ONE_TENSOR) + index = F.expand_dims(index, -1) + return P.TensorScatterUpdate()(data, index, updates) + + +def tensor_setitem_by_tensor_with_number(data, index, value): + index_dtype = F.dtype(index) + tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype) + if tensor_dtype == const_utils.BOOL_: + return _tensor_setitem_by_bool_tensor_with_scalar(data, index, value) + if tensor_dtype == const_utils.INT_: + return _tensor_setitem_by_int_tensor_with_scalar(data, index, value) + return const_utils.raise_index_error("For tensor setitem, indexing tensor dtype only supports bool/int") + + +def tensor_setitem_by_tensor_with_tuple(data, index, value): + """Assigns the tensor by tensor with tuple value.""" + index_dtype = F.dtype(index) + check_dtype = const_utils.check_index_tensor_dtype(index_dtype, const_utils.TENSOR_SETITEM) + result = None + if check_dtype: + result = _tensor_setitem_by_tensor_with_tuple(data, index, value) + return result + + +def _tensor_indices_number(data, data_shape, index, indices, value): + """Assigns a scalar value to the tensor.""" + data_size = F.size(data) + data_dtype = F.dtype(data) + indices_size = F.size(indices) + indices_size = const_utils.check_indices(indices_size, index) + update = F.fill(mstype.int32, (indices_size,), 1) + condition_1d = F.scatter_nd(indices, update, (data_size,)) + condition = F.reshape(condition_1d, data_shape) + condition = F.cast(condition, mstype.bool_) + value_fill = F.fill(data_dtype, (indices_size,), value) + value_1d = F.scatter_nd(indices, value_fill, (data_size,)) + u = F.reshape(value_1d, data_shape) + return F.select(condition, u, data) + + +def _tensor_setitem_by_tensor_with_tuple(data, index, value): + """Set a tensor item by a tensor with a tuple.""" + updates = _generate_updates_from_tuple(data, index, value, + const_utils.SET_ITEM_BY_ONE_TENSOR) + index = F.expand_dims(index, -1) + result = P.TensorScatterUpdate()(data, index, updates) + return result + + +def tensor_setitem_by_slice_with_number(data, input_slice, value): + """Givens a scalar assign to tensor by slice""" + check_result = const_utils.check_tensor_setitem_index(input_slice) + result = None + if check_result: + data_shape = F.shape(data) + indices = const_utils.slice2indices(input_slice, data_shape) + is_tuple_int = const_utils.tuple_element_is_int(input_slice) + if is_tuple_int: + indices = const_utils.integer_to_indices(input_slice, data_shape) + result = _tensor_indices_number(data, data_shape, input_slice, indices, value) + return result + + +def tensor_setitem_by_tuple_with_number(data, tuple_index, value): + """Assigns the tensor by tuple with number value.""" + indexes_types = hyper_map(F.typeof, tuple_index) + index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM) + + if index_elements_type == const_utils.NO_TENSOR: + return tensor_setitem_by_slice_with_number(data, tuple_index, value) + if index_elements_type == const_utils.ALL_TENSOR: + indices = _generate_indices_from_tuple_of_tensor(data, + tuple_index, + const_utils.TENSOR_SETITEM) + else: + indices = _generate_indices_from_tuple_of_mixed_tensors(data, + tuple_index, + const_utils.TENSOR_SETITEM) + updates = _generate_updates_from_scalar(data, + indices, + value, + const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) + return P.TensorScatterUpdate()(data, indices, updates) + + +def _tensor_indices_tensor(data, data_shape, index, indices, value): + """Assigns a tensor value to the tensor.""" + data_size = F.size(data) + data_dtype = F.dtype(data) + indices_size = F.size(indices) + indices_size = const_utils.check_indices(indices_size, index) + update = F.fill(mstype.int32, (indices_size,), 1) + condition_1d = F.scatter_nd(indices, update, (data_size,)) + condition = F.reshape(condition_1d, data_shape) + condition = F.cast(condition, mstype.bool_) + value_fill = None + value_size = F.size(value) + + value_size = const_utils.check_indices_value_size(indices_size, value_size) + if value_size == 1: + value_fill = F.fill(data_dtype, (indices_size,), 1) + value = F.cast(value, data_dtype) + value_fill = F.tensor_mul(value_fill, value) + elif value_size > 1: + value_fill = F.reshape(value, (indices_size,)) + value_1d = F.scatter_nd(indices, value_fill, (data_size,)) + u = F.reshape(value_1d, data_shape) + return F.select(condition, u, data) + + +def tensor_setitem_by_slice_with_tensor(data, input_slice, value): + """Assigns a tensor value to the tensor by slice.""" + result = None + check_result = const_utils.check_tensor_setitem_index(input_slice) + if check_result: + data_shape = F.shape(data) + indices = const_utils.slice2indices(input_slice, data_shape) + is_tuple_int = const_utils.tuple_element_is_int(input_slice) + if is_tuple_int: + indices = const_utils.integer_to_indices(input_slice, data_shape) + result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value) + return result + + +def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): + """Assigns the tensor by tuple with tensor value.""" + indexes_types = hyper_map(F.typeof, tuple_index) + index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM) + + if index_elements_type == const_utils.NO_TENSOR: + return tensor_setitem_by_slice_with_tensor(data, tuple_index, value) + if index_elements_type == const_utils.ALL_TENSOR: + indices = _generate_indices_from_tuple_of_tensor(data, + tuple_index, + const_utils.TENSOR_SETITEM) + else: + indices = _generate_indices_from_tuple_of_mixed_tensors(data, + tuple_index, + const_utils.TENSOR_SETITEM) + updates = _generate_updates_from_tensor(data, + indices, + value, + const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) + return P.TensorScatterUpdate()(data, indices, updates) + + +def tensor_setitem_by_tuple_with_tuple(data, tuple_index, value): + """Assigns the tensor by tuple with tuple of value.""" + indexes_types = hyper_map(F.typeof, tuple_index) + index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM) + + if index_elements_type == const_utils.ALL_TENSOR: + indices = _generate_indices_from_tuple_of_tensor(data, + tuple_index, + const_utils.TENSOR_SETITEM) + else: + indices = _generate_indices_from_tuple_of_mixed_tensors(data, + tuple_index, + const_utils.TENSOR_SETITEM) + updates = _generate_updates_from_tuple(data, + indices, + value, + const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) + return P.TensorScatterUpdate()(data, indices, updates) + + +def tensor_setitem_by_number_with_number(data, index, value): + """Assigns the tensor by number with number value.""" + data_shape = F.shape(data) + indices = const_utils.integer_to_indices(index, data_shape) + return _tensor_indices_number(data, data_shape, index, indices, value) + + +def tensor_setitem_by_number_with_tensor(data, index, value): + """Assigns the tensor by number with tensor value.""" + data_shape = F.shape(data) + indices = const_utils.integer_to_indices(index, data_shape) + return _tensor_indices_tensor(data, data_shape, index, indices, value) + + +def tensor_setitem_by_ellipsis_with_number(data, index, value): + """Assigns the tensor by ellipsis with number value.""" + data_shape = F.shape(data) + data_dtype = F.dtype(data) + return F.fill(data_dtype, data_shape, value) + + +def tensor_setitem_by_ellipsis_with_tensor(data, index, value): + """Assigns the tensor by ellipsis with tensor value.""" + result = None + data_shape = F.shape(data) + data_dtype = F.dtype(data) + data_size = F.size(data) + value_shape = F.shape(value) + value_size = F.size(value) + check_result = const_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size) + if check_result: + if data_size == value_size: + result = F.reshape(value, data_shape) + result = F.cast(result, data_dtype) + elif value_size == 1: + param1 = F.fill(data_dtype, data_shape, 1) + param2 = F.cast(value, data_dtype) + result = F.tensor_mul(param1, param2) + return result diff --git a/mindspore/ops/composite/multitype_ops/setitem_impl.py b/mindspore/ops/composite/multitype_ops/setitem_impl.py index 53659c62055..38cf0141f09 100644 --- a/mindspore/ops/composite/multitype_ops/setitem_impl.py +++ b/mindspore/ops/composite/multitype_ops/setitem_impl.py @@ -16,10 +16,8 @@ """Implementation for setitem.""" from . import _compile_utils as compile_utils -from . import _constexpr_utils as const_utils from ... import functional as F from ...composite import base -from ....common import dtype as mstype setitem = base.MultitypeFuncGraph('setitem') @@ -139,11 +137,7 @@ def _tensor_setitem_by_tensor_with_tensor(data, index, value_tensor): Outputs: Tensor, element type and shape is same as data. """ - index_dtype = F.dtype(index) - tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype) - if tensor_dtype == const_utils.INT_: - return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor) - return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor) + return compile_utils.tensor_setitem_by_tensor_with_tensor(data, index, value_tensor) @setitem.register("Tensor", "Tensor", "Number") @@ -166,11 +160,7 @@ def _tensor_setitem_by_tensor_with_number(data, index, value): Outputs: Tensor, element type and shape is same as data. """ - index_dtype = F.dtype(index) - tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype) - if tensor_dtype == const_utils.BOOL_: - return _tensor_setitem_by_bool_tensor_with_scalar(data, index, value) - return _tensor_setitem_by_int_tensor_with_scalar(data, index, value) + return compile_utils.tensor_setitem_by_tensor_with_number(data, index, value) @setitem.register("Tensor", "Tuple", "Number") @@ -191,24 +181,7 @@ def _tensor_setitem_by_tuple_with_number(data, tuple_index, value): Outputs: Tensor, element type and shape is same as data. """ - indexes_types = compile_utils.hyper_map(F.typeof, tuple_index) - index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM) - - if index_elements_type == const_utils.NO_TENSOR: - return _tensor_assgin_number(data, tuple_index, value) - if index_elements_type == const_utils.ALL_TENSOR: - indices = compile_utils.generate_indices_from_tuple_of_tensor(data, - tuple_index, - const_utils.TENSOR_SETITEM) - else: - indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data, - tuple_index, - const_utils.TENSOR_SETITEM) - updates = compile_utils.generate_updates_from_scalar(data, - indices, - value, - const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) - return F.scatter_nd_update(data, indices, updates) + return compile_utils.tensor_setitem_by_tuple_with_number(data, tuple_index, value) @setitem.register("Tensor", "Tuple", "Tensor") @@ -229,24 +202,7 @@ def _tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): Outputs: Tensor, element type and shape is same as data. """ - indexes_types = compile_utils.hyper_map(F.typeof, tuple_index) - index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM) - - if index_elements_type == const_utils.NO_TENSOR: - return _tensor_assgin_tensor(data, tuple_index, value) - if index_elements_type == const_utils.ALL_TENSOR: - indices = compile_utils.generate_indices_from_tuple_of_tensor(data, - tuple_index, - const_utils.TENSOR_SETITEM) - else: - indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data, - tuple_index, - const_utils.TENSOR_SETITEM) - updates = compile_utils.generate_updates_from_tensor(data, - indices, - value, - const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) - return F.scatter_nd_update(data, indices, updates) + return compile_utils.tensor_setitem_by_tuple_with_tensor(data, tuple_index, value) @setitem.register("Tensor", "Tuple", "Tuple") @@ -268,22 +224,7 @@ def _tensor_setitem_by_tuple_with_tuple(data, tuple_index, value): Outputs: Tensor, element type and shape is same as data. """ - indexes_types = compile_utils.hyper_map(F.typeof, tuple_index) - index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM) - - if index_elements_type == const_utils.ALL_TENSOR: - indices = compile_utils.generate_indices_from_tuple_of_tensor(data, - tuple_index, - const_utils.TENSOR_SETITEM) - else: - indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data, - tuple_index, - const_utils.TENSOR_SETITEM) - updates = compile_utils.generate_updates_from_tuple(data, - indices, - value, - const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) - return F.scatter_nd_update(data, indices, updates) + return compile_utils.tensor_setitem_by_tuple_with_tuple(data, tuple_index, value) @setitem.register("Tensor", "Tensor", "Tuple") @@ -299,12 +240,7 @@ def _tensor_setitem_by_tensor_v2(data, index, value): Outputs: Tensor, element type and shape is same as data. """ - index_dtype = F.dtype(index) - check_dtype = const_utils.check_index_tensor_dtype(index_dtype, const_utils.TENSOR_SETITEM) - result = None - if check_dtype: - result = _tensor_setitem_by_tensor_with_tuple(data, index, value) - return result + return compile_utils.tensor_setitem_by_tensor_with_tuple(data, index, value) @setitem.register("Tensor", "Slice", "Tensor") @@ -326,7 +262,7 @@ def _tensor_setitem_with_slice_v3(data, input_slice, value): Outputs: Tensor, element type and shape is same as data. """ - return _tensor_assgin_tensor(data, input_slice, value) + return compile_utils.tensor_setitem_by_slice_with_tensor(data, input_slice, value) @setitem.register("Tensor", "Slice", "Number") @@ -348,168 +284,28 @@ def _tensor_setitem_with_slice_v1(data, input_slice, value): Outputs: Tensor, element type and shape is same as data. """ - return _tensor_assgin_number(data, input_slice, value) - - -def _tensor_assgin_number(data, input_slice, value): - """Givens a scalar assign to tensor by slice""" - check_result = const_utils.check_tensor_setitem_index(input_slice) - result = None - if check_result: - data_shape = F.shape(data) - indices = const_utils.slice2indices(input_slice, data_shape) - is_tuple_int = const_utils.tuple_element_is_int(input_slice) - if is_tuple_int: - indices = const_utils.integer_to_indices(input_slice, data_shape) - result = _tensor_indices_number(data, data_shape, input_slice, indices, value) - return result + return compile_utils.tensor_setitem_by_slice_with_number(data, input_slice, value) @setitem.register("Tensor", "Number", "Number") def _tensor_setitem_with_int_v1(data, index, value): """Syntax: A[1] = 3""" - data_shape = F.shape(data) - indices = const_utils.integer_to_indices(index, data_shape) - return _tensor_indices_number(data, data_shape, index, indices, value) + return compile_utils.tensor_setitem_by_number_with_number(data, index, value) @setitem.register("Tensor", "Number", "Tensor") def _tensor_setitem_with_int_v2(data, index, value): """Syntax: A[1] = Tensor""" - data_shape = F.shape(data) - indices = const_utils.integer_to_indices(index, data_shape) - return _tensor_indices_tensor(data, data_shape, index, indices, value) + return compile_utils.tensor_setitem_by_number_with_tensor(data, index, value) @setitem.register("Tensor", "Ellipsis", "Number") def _tensor_setitem_with_ellipsis_v1(data, index, value): """Syntax: A[...] = number.""" - data_shape = F.shape(data) - data_dtype = F.dtype(data) - return F.fill(data_dtype, data_shape, value) + return compile_utils.tensor_setitem_by_ellipsis_with_number(data, index, value) @setitem.register("Tensor", "Ellipsis", "Tensor") def _tensor_setitem_with_ellipsis_v2(data, index, value): """Syntax: A[...] = Tensor.""" - result = None - data_shape = F.shape(data) - data_dtype = F.dtype(data) - data_size = F.size(data) - value_shape = F.shape(value) - value_size = F.size(value) - check_result = const_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size) - if check_result: - if data_size == value_size: - result = F.reshape(value, data_shape) - result = F.cast(result, data_dtype) - elif value_size == 1: - param1 = F.fill(data_dtype, data_shape, 1) - param2 = F.cast(value, data_dtype) - result = F.tensor_mul(param1, param2) - return result - - -def _tensor_assgin_tensor(data, input_slice, value): - """Assigns a tensor value to the tensor by slice.""" - result = None - check_result = const_utils.check_tensor_setitem_index(input_slice) - if check_result: - data_shape = F.shape(data) - indices = const_utils.slice2indices(input_slice, data_shape) - is_tuple_int = const_utils.tuple_element_is_int(input_slice) - if is_tuple_int: - indices = const_utils.integer_to_indices(input_slice, data_shape) - result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value) - return result - - -def _tensor_indices_tensor(data, data_shape, index, indices, value): - """Assigns a tensor value to the tensor.""" - data_size = F.size(data) - data_dtype = F.dtype(data) - indices_size = F.size(indices) - indices_size = const_utils.check_indices(indices_size, index) - update = F.fill(mstype.int32, (indices_size,), 1) - condition_1d = F.scatter_nd(indices, update, (data_size,)) - condition = F.reshape(condition_1d, data_shape) - condition = F.cast(condition, mstype.bool_) - value_fill = None - value_size = F.size(value) - - value_size = const_utils.check_indices_value_size(indices_size, value_size) - if value_size == 1: - value_fill = F.fill(data_dtype, (indices_size,), 1) - value = F.cast(value, data_dtype) - value_fill = F.tensor_mul(value_fill, value) - elif value_size > 1: - value_fill = F.reshape(value, (indices_size,)) - value_1d = F.scatter_nd(indices, value_fill, (data_size,)) - u = F.reshape(value_1d, data_shape) - return F.select(condition, u, data) - - -def _tensor_indices_number(data, data_shape, index, indices, value): - """Assigns a scalar value to the tensor.""" - data_size = F.size(data) - data_dtype = F.dtype(data) - indices_size = F.size(indices) - indices_size = const_utils.check_indices(indices_size, index) - update = F.fill(mstype.int32, (indices_size,), 1) - condition_1d = F.scatter_nd(indices, update, (data_size,)) - condition = F.reshape(condition_1d, data_shape) - condition = F.cast(condition, mstype.bool_) - value_fill = F.fill(data_dtype, (indices_size,), value) - value_1d = F.scatter_nd(indices, value_fill, (data_size,)) - u = F.reshape(value_1d, data_shape) - return F.select(condition, u, data) - - -def _tensor_setitem_by_tensor_with_tuple(data, index, value): - """Set a tensor item by a tensor with a tuple.""" - updates = compile_utils.generate_updates_from_tuple(data, index, value, - const_utils.SET_ITEM_BY_ONE_TENSOR) - result = F.scatter_update(data, index, updates) - return result - - -def _tensor_setitem_by_int_tensor_with_scalar(data, index, value): - """Set a tensor item by a int tensor with a scalar.""" - updates = compile_utils.generate_updates_from_scalar(data, index, value, - const_utils.SET_ITEM_BY_ONE_TENSOR) - return F.scatter_update(data, index, updates) - - -def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value): - """Set a tensor item by a bool tensor with a scalar.""" - index_shape = F.shape(index) - shape = F.shape(data) - shape = const_utils.check_equal( - shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.") - dtype = F.dtype(data) - u = F.fill(dtype, shape, value) - return F.select(index, u, data) - - -def _tensor_setitem_by_int_tensor_with_tensor(data, index, value): - """Set a tensor item by a int tensor with a tensor.""" - updates = compile_utils.generate_updates_from_tensor(data, index, value, - const_utils.SET_ITEM_BY_ONE_TENSOR) - return F.scatter_update(data, index, updates) - - -def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value): - """Set a tensor item by a bool tensor with a tensor.""" - index_shape = F.shape(index) - data_shape = F.shape(data) - data_shape = const_utils.check_equal(data_shape, index_shape, - "The tensor(shape={}) and tensor index(shape={}) should be the same shape.") - size = F.size(value) - size = const_utils.check_equal(1, size, - "When assign value is a tensor, its size should be {}, but current size is {}.") - dtype = F.dtype(data) - u_cast = F.cast(value, dtype) - one_data = F.ones_like(data) - u = F.tensor_mul(one_data, u_cast) - result = F.select(index, u, data) - return result + return compile_utils.tensor_setitem_by_ellipsis_with_tensor(data, index, value) diff --git a/tests/st/pynative/test_tensor_index.py b/tests/st/pynative/test_tensor_index.py index 879b4f4c2e5..77ee7db5d65 100644 --- a/tests/st/pynative/test_tensor_index.py +++ b/tests/st/pynative/test_tensor_index.py @@ -20,10 +20,14 @@ from mindspore import Tensor, Parameter from mindspore import context from mindspore import dtype as mstype from mindspore.nn import Cell +from mindspore.common.parameter import ParameterTuple +from mindspore.ops import composite as C + def setup_module(): context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + class NetWorkSlicePositive(Cell): def __init__(self): super(NetWorkSlicePositive, self).__init__() @@ -139,7 +143,7 @@ class TensorGetItemByThreeTensors(Cell): return ret0, ret1, ret2 -def Xtest_getitem_by_tensors(): +def test_getitem_by_tensors(): net = TensorGetItemByThreeTensors() input_x = np.arange(6*8*10).reshape(6, 8, 10).astype(np.int32) index_0 = np.random.randint(6, size=(3, 4, 5)).astype(np.int32) @@ -155,119 +159,140 @@ def Xtest_getitem_by_tensors(): assert np.all(output2.asnumpy() == input_x[index_0, index_1, index_2] + np.ones([5, 3, 4, 5])) -class TensorGetItemByMixedTensors_0(Cell): - def __init__(self): - super(TensorGetItemByMixedTensors_0, self).__init__() - self.const = Tensor(np.ones((3, 4, 5, 3, 6, 5), np.float32)) +class TensorGetItemByMixedTensorsBasicCase(Cell): + def __init__(self, c0, c1, c2, c3, c4, c5): + super(TensorGetItemByMixedTensorsBasicCase, self).__init__() + self.const0 = Tensor(c0) + self.const1 = Tensor(c1) + self.const2 = Tensor(c2) + self.const3 = Tensor(c3) + self.const4 = Tensor(c4) + self.const5 = Tensor(c5) def construct(self, tensor, index_0, index_1): - ret = tensor[index_0, index_1, 0:3, ..., 0:5, 3] + self.const - return ret + ret0 = tensor[index_0, index_1, 0:3] + self.const0 + ret1 = tensor[0:3, index_0, ...] + self.const1 + ret2 = tensor[0, index_0, index_1] + self.const2 + ret3 = tensor[..., index_0, 0:3] + self.const3 + ret4 = tensor[0:2, index_0, index_1] + self.const4 + ret5 = tensor[..., index_0, index_1] + self.const5 + return ret0, ret1, ret2, ret3, ret4, ret5 -class TensorGetItemByMixedTensors_1(Cell): - def __init__(self): - super(TensorGetItemByMixedTensors_1, self).__init__() - self.const = Tensor(np.ones((3, 4, 5, 3, 5, 5), np.float32)) - - def construct(self, tensor, index_0, index_1): - ret = tensor[0:3, index_0, ..., index_1, 3, 0:5] + self.const - return ret - - -class TensorGetItemByMixedTensors_2(Cell): - def __init__(self): - super(TensorGetItemByMixedTensors_2, self).__init__() - self.const = Tensor(np.ones((3, 4, 5, 6, 7), np.float32)) - - def construct(self, tensor, index_0, index_1): - ret = tensor[0, index_0, index_1, ..., 3] + self.const - return ret - - -class TensorGetItemByMixedTensors_3(Cell): - def __init__(self): - super(TensorGetItemByMixedTensors_3, self).__init__() - self.const = Tensor(np.ones((3, 4, 5, 3, 4, 3, 5), np.float32)) - - def construct(self, tensor, index_0, index_1): - ret = tensor[..., index_0, 0:3, index_1, 0:5] + self.const - return ret - - -class TensorGetItemByMixedTensors_4(Cell): - def __init__(self): - super(TensorGetItemByMixedTensors_4, self).__init__() - self.const = Tensor(np.ones((2, 2, 3, 4, 5, 3, 9), np.float32)) - - def construct(self, tensor, index_0, index_1, index_2): - ret = tensor[0:2, index_0, index_1, 2, index_2, 0:3, ...] + self.const - return ret - - -class TensorGetItemByMixedTensors_5(Cell): - def __init__(self): - super(TensorGetItemByMixedTensors_5, self).__init__() - self.const = Tensor(np.ones((2, 3, 4, 5, 2, 6), np.float32)) - - def construct(self, tensor, index_0, index_1, index_2): - ret = tensor[0:2, index_0, index_1, ..., index_2, 2] + self.const - return ret - - -class TensorGetItemByMixedTensors_6(Cell): - def __init__(self): - super(TensorGetItemByMixedTensors_6, self).__init__() - self.const = Tensor(np.ones((3, 4, 2, 3, 4, 5), np.float32)) - - def construct(self, tensor, index_0, index_1, index_2): - ret = tensor[..., index_0, index_1, index_2, 3] + self.const - return ret +def test_getitem_by_mixed_tensors(): + const0 = np.ones((3, 4, 5, 3), np.float32) + const1 = np.ones((3, 3, 4, 5, 5), np.float32) + const2 = np.ones((3, 4, 5), np.float32) + const3 = np.ones((3, 3, 4, 5, 3), np.float32) + const4 = np.ones((2, 3, 4, 5), np.float32) + const5 = np.ones((3, 3, 4, 5), np.float32) + net = TensorGetItemByMixedTensorsBasicCase(const0, const1, const2, const3, const4, const5) + input_np = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32) + input_ms = Tensor(input_np, mstype.float32) + index_np_0 = np.random.randint(3, size=(3, 4, 5)).astype(np.int32) + index_np_1 = np.random.randint(4, size=(4, 5)).astype(np.int32) + index_0 = Tensor(index_np_0, mstype.int32) + index_1 = Tensor(index_np_1, mstype.int32) + out0, out1, out2, out3, out4, out5 = net(input_ms, index_0, index_1) + assert np.all(out0.asnumpy() == (input_np[index_np_0, index_np_1, 0:3] + const0)) + assert np.all(out1.asnumpy() == (input_np[0:3, index_np_0, ...] + const1)) + assert np.all(out2.asnumpy() == (input_np[0, index_np_0, index_np_1] + const2)) + assert np.all(out3.asnumpy() == (input_np[..., index_np_0, 0:3] + const3)) + assert np.all(out4.asnumpy() == (input_np[0:2, index_np_0, index_np_1] + const4)) + assert np.all(out5.asnumpy() == (input_np[..., index_np_0, index_np_1] + const5)) class TensorSetItemByMixedTensors_0(Cell): def __init__(self, value): super(TensorSetItemByMixedTensors_0, self).__init__() - self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8, 9), np.float32)) - self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), + self.const = Tensor(np.ones((3, 4, 5), np.float32)) + self.param = Parameter(Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32), name="x") self.value = value def construct(self, index_0, index_1, index_2): - self.param[0:2, index_0, index_1, 2, index_2, 0:3, ...] = self.value + self.param[0:2, index_0, index_1] = self.value ret = self.param + self.const return ret +def test_setitem_by_mixed_tensors_0(): + value = 88.0 + net = TensorSetItemByMixedTensors_0(value) + index_0 = np.random.randint(3, size=(3, 4, 5)) + index_1 = np.random.randint(4, size=(4, 5)) + index_2 = np.random.randint(3, size=(2, 1, 4, 5)) + index_0_ms = Tensor(index_0, mstype.int32) + index_1_ms = Tensor(index_1, mstype.int32) + index_2_ms = Tensor(index_2, mstype.int32) + input_np = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32) + const = np.ones((3, 4, 5), np.float32) + out = net(index_0_ms, index_1_ms, index_2_ms) + input_np[0:2, index_0, index_1] = value + assert np.all(out.asnumpy() == (input_np + const)) + + class TensorSetItemByMixedTensors_1(Cell): def __init__(self, value): super(TensorSetItemByMixedTensors_1, self).__init__() - self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float32)) - self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32), + self.const = Tensor(np.ones((3, 4, 5), np.float32)) + self.param = Parameter(Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32), name="x") self.value = value def construct(self, index_0, index_1, index_2): - self.param[0:2, index_0, index_1, ..., index_2, 2] = self.value + self.param[0:2, index_0, ...] = self.value ret = self.param + self.const return ret +def test_setitem_by_mixed_tensors_1(): + value = 88.0 + net = TensorSetItemByMixedTensors_1(value) + index_0 = np.random.randint(3, size=(3, 4, 5)) + index_1 = np.random.randint(4, size=(4, 5)) + index_2 = np.random.randint(3, size=(2, 1, 4, 5)) + index_0_ms = Tensor(index_0, mstype.int32) + index_1_ms = Tensor(index_1, mstype.int32) + index_2_ms = Tensor(index_2, mstype.int32) + input_np = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32) + const = np.ones((3, 4, 5), np.float32) + out = net(index_0_ms, index_1_ms, index_2_ms) + input_np[0:2, index_0, ...] = value + assert np.all(out.asnumpy() == (input_np + const)) + + class TensorSetItemByMixedTensors_2(Cell): def __init__(self, value): super(TensorSetItemByMixedTensors_2, self).__init__() - self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float16)) - self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float16), + self.const = Tensor(np.ones((3, 4, 5), np.float16)) + self.param = Parameter(Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float16), name="x") self.value = value def construct(self, index_0, index_1, index_2): - self.param[..., index_0, index_1, index_2, 3] = self.value + self.param[..., index_0, 1] = self.value ret = self.param + self.const return ret +def test_setitem_by_mixed_tensors_2(): + value = 88.0 + net = TensorSetItemByMixedTensors_2(value) + index_0 = np.random.randint(3, size=(3, 4, 5)) + index_1 = np.random.randint(4, size=(4, 5)) + index_2 = np.random.randint(3, size=(2, 1, 4, 5)) + index_0_ms = Tensor(index_0, mstype.int32) + index_1_ms = Tensor(index_1, mstype.int32) + index_2_ms = Tensor(index_2, mstype.int32) + input_np = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32) + const = np.ones((3, 4, 5), np.float32) + out = net(index_0_ms, index_1_ms, index_2_ms) + input_np[..., index_0, 1] = value + assert np.all(out.asnumpy() == (input_np + const)) + + class TensorGetItemByMixedTensorsTypeError(Cell): def __init__(self): super(TensorGetItemByMixedTensorsTypeError, self).__init__() @@ -277,13 +302,13 @@ class TensorGetItemByMixedTensorsTypeError(Cell): return ret -class TensorGetItemByMixedTensorsNumberError(Cell): - def __init__(self): - super(TensorGetItemByMixedTensorsNumberError, self).__init__() - - def construct(self, x, index_0, index_1): - ret = x[index_0, index_1, 0:3, ..., index_1, index_0] - return ret +def test_getitem_by_mixedtensor_exception(): + input_ms = Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32) + index_0 = Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32) + index_1 = Tensor(np.random.randint(4, size=(3, 4, 5)), mstype.int32) + net1 = TensorGetItemByMixedTensorsTypeError() + with pytest.raises(TypeError): + net1(input_ms, index_0, index_1) class TensorSetItemByOneTensorWithNumber(Cell): @@ -299,6 +324,18 @@ class TensorSetItemByOneTensorWithNumber(Cell): return ret +def test_setitem_one_tensor_with_number(): + value = 0.0 + net = TensorSetItemByOneTensorWithNumber(value) + index_np = np.random.randint(4, size=(5, 4)) + index = Tensor(index_np, mstype.int32) + input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)) + const = np.ones((6, 7, 8)).astype(np.float32) + out = net(index) + input_data[index_np] = value + assert np.all(out.asnumpy() == (input_data + const)) + + class TensorSetItemByOneTensorWithTensor(Cell): def __init__(self): super(TensorSetItemByOneTensorWithTensor, self).__init__() @@ -311,6 +348,19 @@ class TensorSetItemByOneTensorWithTensor(Cell): return ret +def test_setitem_by_one_tensor_with_tensor(): + net = TensorSetItemByOneTensorWithTensor() + index_np = np.random.randint(4, size=(5, 4)) + index = Tensor(index_np, mstype.int32) + input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)) + const = np.ones((6, 7, 8)).astype(np.float32) + value = np.zeros((4, 7, 8)).astype(np.float32) + value_ms = Tensor(value, mstype.float32) + out = net(index, value_ms) + input_data[index_np] = value + assert np.all(out.asnumpy() == (input_data + const)) + + class TensorSetItemByOneTensorWithTupleOfNumber(Cell): def __init__(self, value): super(TensorSetItemByOneTensorWithTupleOfNumber, self).__init__() @@ -324,6 +374,18 @@ class TensorSetItemByOneTensorWithTupleOfNumber(Cell): return ret +def test_setitem_by_one_tensor_with_tuple_number(): + value = (0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7) + net = TensorSetItemByOneTensorWithTupleOfNumber(value) + input_np = np.random.randint(5, size=(5, 4)) + input_ms = Tensor(input_np, mstype.int32) + input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32) + const = np.ones((6, 7, 8)).astype(np.float32) + out = net(input_ms) + input_data[input_np] = value + assert np.all(out.asnumpy() == (input_data + const)) + + class TensorSetItemByOneTensorWithTupleOfTensor(Cell): def __init__(self): super(TensorSetItemByOneTensorWithTupleOfTensor, self).__init__() @@ -336,6 +398,23 @@ class TensorSetItemByOneTensorWithTupleOfTensor(Cell): return ret +def test_setitem_by_one_tensor_with_tuple_tensors(): + net = TensorSetItemByOneTensorWithTupleOfTensor() + input_np = np.random.randint(6, size=(5, 4)).astype(np.int32) + input_ms = Tensor(input_np, mstype.int32) + input_data = np.arange(6 * 3 * 8).reshape((6, 3, 8)).astype(np.float32) + value_0_np = np.zeros((8,), np.float32) + value_1_np = np.ones((8,), np.float32) + value_2_np = np.ones((8,), np.float32)*2 + value_0 = Tensor(value_0_np) + value_1 = Tensor(value_1_np) + value_2 = Tensor(value_2_np) + const = np.ones((6, 3, 8)).astype(np.float32) + out = net(input_ms, value_0, value_1, value_2) + input_data[input_np] = (value_0_np, value_1_np, value_2_np) + assert np.all(out.asnumpy() == (input_data + const)) + + class TensorSetItemByTensorsWithNumber(Cell): def __init__(self, value): super(TensorSetItemByTensorsWithNumber, self).__init__() @@ -349,6 +428,22 @@ class TensorSetItemByTensorsWithNumber(Cell): return ret +def test_setitem_by_tensors_with_number(): + value = 0.0 + net = TensorSetItemByTensorsWithNumber(value) + index_0 = np.random.randint(6, size=(3, 4, 5)) + index_1 = np.random.randint(7, size=(4, 5)) + index_2 = np.random.randint(8, size=(5, 3, 4, 5)) + index_0_ms = Tensor(index_0, mstype.int32) + index_1_ms = Tensor(index_1, mstype.int32) + index_2_ms = Tensor(index_2, mstype.int32) + out = net(index_0_ms, index_1_ms, index_2_ms) + const = np.ones((6, 7, 8)).astype(np.float32) + input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32) + input_data[index_0, index_1, index_2] = value + assert np.all(out.asnumpy() == (input_data + const)) + + class TensorSetItemByTensorsWithTensor(Cell): def __init__(self): super(TensorSetItemByTensorsWithTensor, self).__init__() @@ -361,6 +456,23 @@ class TensorSetItemByTensorsWithTensor(Cell): return ret +def test_setitem_by_tensors_with_tensor(): + net = TensorSetItemByTensorsWithTensor() + index_0 = np.random.randint(6, size=(3, 4, 5)) + index_1 = np.random.randint(7, size=(4, 5)) + index_2 = np.random.randint(8, size=(5, 3, 4, 5)) + value = np.zeros((4, 5)).astype(np.float32) + index_0_ms = Tensor(index_0, mstype.int32) + index_1_ms = Tensor(index_1, mstype.int32) + index_2_ms = Tensor(index_2, mstype.int32) + value_ms = Tensor(value, mstype.float32) + out = net(index_0_ms, index_1_ms, index_2_ms, value_ms) + const = np.ones((6, 7, 8)).astype(np.float32) + input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32) + input_data[index_0, index_1, index_2] = value + assert np.all(out.asnumpy() == (input_data + const)) + + class TensorSetItemByTensorsWithTensorNumberError(Cell): def __init__(self): super(TensorSetItemByTensorsWithTensorNumberError, self).__init__() @@ -373,6 +485,17 @@ class TensorSetItemByTensorsWithTensorNumberError(Cell): return ret +def test_setitem_by_tensors_with_tensor_error(): + index_0 = Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32) + index_1 = Tensor(np.random.randint(7, size=(4, 5)), mstype.int32) + index_2 = Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32) + index_3 = Tensor(np.random.randint(8, size=(1, 3, 4, 5)), mstype.int32) + value = Tensor(np.zeros((2, 5)), mstype.float32) + net = TensorSetItemByTensorsWithTensorNumberError() + with pytest.raises(IndexError): + net(index_0, index_1, index_2, index_3, value) + + class TensorSetItemByTensorsWithTupleOfNumber(Cell): def __init__(self, value): super(TensorSetItemByTensorsWithTupleOfNumber, self).__init__() @@ -386,6 +509,22 @@ class TensorSetItemByTensorsWithTupleOfNumber(Cell): return ret +def test_setitem_by_tensors_with_tuple_of_number(): + value = (0.0, 1.1, 2.2, 3.3, 4.4) + net = TensorSetItemByTensorsWithTupleOfNumber(value) + index_0 = np.random.randint(6, size=(3, 4, 5)) + index_1 = np.random.randint(7, size=(4, 5)) + index_2 = np.random.randint(8, size=(5, 3, 4, 5)) + index_0_ms = Tensor(index_0, mstype.int32) + index_1_ms = Tensor(index_1, mstype.int32) + index_2_ms = Tensor(index_2, mstype.int32) + input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32) + input_data[index_0, index_1, index_2] = value + const = np.ones((6, 7, 8)).astype(np.float32) + out = net(index_0_ms, index_1_ms, index_2_ms) + assert np.all(out.asnumpy() == (input_data + const)) + + class TensorSetItemByTensorsWithTupleOfTensor(Cell): def __init__(self): super(TensorSetItemByTensorsWithTupleOfTensor, self).__init__() @@ -398,6 +537,27 @@ class TensorSetItemByTensorsWithTupleOfTensor(Cell): return ret +def test_setitem_by_tensors_with_tuple_of_tensor(): + value_0 = np.zeros((4, 5)) + value_1 = np.ones((4, 5)) + value_2 = np.ones((4, 5)) * 2 + value_0_ms = Tensor(value_0, mstype.float32) + value_1_ms = Tensor(value_1, mstype.float32) + value_2_ms = Tensor(value_2, mstype.float32) + net = TensorSetItemByTensorsWithTupleOfTensor() + index_0 = np.random.randint(6, size=(3, 4, 5)) + index_1 = np.random.randint(7, size=(4, 5)) + index_2 = np.random.randint(8, size=(5, 3, 4, 5)) + index_0_ms = Tensor(index_0, mstype.int32) + index_1_ms = Tensor(index_1, mstype.int32) + index_2_ms = Tensor(index_2, mstype.int32) + input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32) + input_data[index_0, index_1, index_2] = (value_0, value_1, value_2) + const = np.ones((6, 7, 8)).astype(np.float32) + out = net(index_0_ms, index_1_ms, index_2_ms, value_0_ms, value_1_ms, value_2_ms) + assert np.all(out.asnumpy() == (input_data + const)) + + class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell): def __init__(self): super(TensorSetItemByTensorsWithTupleOfTensorNumberError, self).__init__() @@ -410,17 +570,44 @@ class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell): return ret -class TensorSetItemByMixedTensors(Cell): - def __init__(self): - super(TensorSetItemByMixedTensors, self).__init__() - self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) - self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") - self.value = 99.0 +def test_setitem_by_tensor_with_tuple_of_tensor_error(): + net = TensorSetItemByTensorsWithTupleOfTensorNumberError() + index_0_ms = Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32) + index_1_ms = Tensor(np.random.randint(7, size=(4, 5)), mstype.int32) + index_2_ms = Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32) + value_0 = np.zeros((4, 5)) + value_1 = np.ones((4, 5)) + value_0_ms = Tensor(value_0, mstype.float32) + value_1_ms = Tensor(value_1, mstype.float32) + with pytest.raises(ValueError): + net(index_0_ms, index_1_ms, index_2_ms, value_0_ms, value_1_ms) - def construct(self, index_0, index_1): - self.param[index_0, index_1, 0:6] = self.value - ret = self.param + self.const - return ret + +def test_setitem_grad(): + class Net(Cell): + def __init__(self): + super(Net, self).__init__() + self.weight = Parameter( + Tensor(np.ones([4, 4, 5]), dtype=mstype.float32), "b1", requires_grad=True) + + def construct(self, a, b): + a[1:3:1, ::] = b + c = a + self.weight + return c + + class GradNet(Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.weights = ParameterTuple(net.trainable_params()) + + def construct(self, x, y, sens): + return C.grad_by_list_with_sens(self.net, self.weights)(x, y, sens) + net = GradNet(Net()) + x = Tensor(np.ones([4, 4, 5]).astype(np.float32), mstype.float32) + y = Tensor(np.array([3]).astype(np.float32), mstype.float32) + sens = Tensor(np.ones([4, 4, 5]).astype(np.float32), mstype.float32) + net(x, y, sens) class TensorAssignWithSliceError1(Cell): @@ -475,7 +662,6 @@ class TensorAssignWithSlice(Cell): def test_tensor_assign(): - context.set_context(mode=context.GRAPH_MODE, save_graphs=True) net = TensorAssignWithSlice() net2 = TensorAssignWithSlice2() net_e1 = TensorAssignWithSliceError1() @@ -621,7 +807,7 @@ class TensorAssignWithTupleInteger(Cell): class TensorAssignWithBoolTensorIndex(Cell): def __init__(self): super(TensorAssignWithBoolTensorIndex, self).__init__() - self.t = Tensor(np.arange(60).reshape([3, 4, 5]), dtype=mstype.float32) + self.t = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) self.u_scalar = 5 def construct(self, a, b, c, u_tensor): @@ -643,8 +829,7 @@ class TensorAssignWithBoolTensorIndexError(Cell): class TensorAssignWithBoolTensorIndex2(Cell): def __init__(self): super(TensorAssignWithBoolTensorIndex2, self).__init__() - self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float32) - self.t = Tensor(np.arange(60).reshape([3, 4, 5]), dtype=mstype.float32) + self.t = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) self.u_scalar = 5 def construct(self, a, u_tensor): @@ -666,7 +851,40 @@ class TensorAssignWithBoolTensorIndex2Error(Cell): return a -def Xtest_tensor_assign_bool_index(): +def test_tensor_assign_bool_index_0(): + a = np.arange(60).reshape(3, 4, 5) + b = a > 5 + c = a < 3 + Ta = Tensor(a, dtype=mstype.float32) + Tb = Tensor(b) + Tc = Tensor(c) + u_tensor = Tensor([1], dtype=mstype.float32) + net1 = TensorAssignWithBoolTensorIndex() + out = net1(Ta, Tb, Tc, u_tensor) + res = np.arange(60).reshape(3, 4, 5) + res[c] = 5 + res[b] = 1 + res = res + np.ones([3, 4, 5]) + assert np.all(out.asnumpy() == res) + + +def test_tensor_assign_bool_index_1(): + a = np.arange(60).reshape(3, 4, 5) + Ta = Tensor(a, dtype=mstype.float32) + u_tensor = Tensor([1], dtype=mstype.float32) + net2 = TensorAssignWithBoolTensorIndex2() + out = net2(Ta, u_tensor) + res = np.arange(60).reshape(3, 4, 5) + res[res > 8] = 1 + res[res >= 6] = 5 + res[res < 3] = 5 + res[res <= 5] = 1 + res[res == 5] = 5 + res = res + np.ones([3, 4, 5]) + assert np.all(out.asnumpy() == res) + + +def test_tensor_assign_bool_index_exception(): a = np.arange(60).reshape(3, 4, 5) b = a > 5 c = a < 3 @@ -679,8 +897,6 @@ def Xtest_tensor_assign_bool_index(): u_scalar = 5 net1 = TensorAssignWithBoolTensorIndex() net2 = TensorAssignWithBoolTensorIndex2() - net1(Ta, Tb, Tc, u_tensor) - net1(Ta, Tb, Tc, u_tensor) with pytest.raises(ValueError): net1(Ta, Td, Tc, u_tensor) with pytest.raises(IndexError): @@ -695,14 +911,14 @@ def Xtest_tensor_assign_bool_index(): with pytest.raises(ValueError): net2(Ta, u_tensor_error) net3 = TensorAssignWithBoolTensorIndexError() - with pytest.raises(AttributeError): + with pytest.raises(IndexError): net3(Ta, Tb, Tc, u_tensor) - with pytest.raises(AttributeError): + with pytest.raises(IndexError): net3(Ta, Tb, Tc, u_scalar) net4 = TensorAssignWithBoolTensorIndex2Error() - with pytest.raises(AttributeError): + with pytest.raises(IndexError): net4(Ta, u_tensor) - with pytest.raises(AttributeError): + with pytest.raises(IndexError): net4(Ta, u_scalar)