add support for parameter

support for tensor setitem

add support for tensor assgin
This commit is contained in:
huangdongrun 2020-06-12 16:08:54 +08:00
parent c55b81e94f
commit 79058d3509
7 changed files with 699 additions and 360 deletions

View File

@ -92,6 +92,10 @@ Tensor &Tensor::operator=(const Tensor &tensor) {
}
return *this;
}
Tensor &Tensor::AssignValue(const Tensor &tensor) {
*this = tensor;
return *this;
}
bool Tensor::operator==(const Tensor &tensor) const {
return (MetaTensor::operator==(tensor) && data_ == tensor.data_);
@ -470,6 +474,19 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
>>> data.set_dtype(mindspore.int32)
mindspore.int32
)mydelimiter")
.def("assign_value", &Tensor::AssignValue, R"mydelimiter(
Assign another tensor value to this.
Arg:
value (:class:`mindspore.tensor`): The value tensor.
Examples:
>>> data = mindspore.Tensor(np.ones((1, 2), np.float32))
>>> data2 = mindspore.Tensor(np.ones((2, 2), np.float32))
>>> data.assign_value(data2)
>>> data.shape
(2, 2)
)mydelimiter")
.def("__str__", &Tensor::ToString)
.def("__repr__", &Tensor::ToStringRepr)
.def(py::pickle(

View File

@ -173,6 +173,9 @@ class Tensor : public MetaTensor {
// It is different from 'operator==' which just compare shape/type/address, it do real value comparison.
bool ValueEqual(const Tensor &other) const;
// assgin value to this tensor
Tensor &AssignValue(const Tensor &tensor);
bool operator==(const Value &other) const override {
if (other.isa<Tensor>()) {
auto other_ = static_cast<const Tensor &>(other);

View File

@ -203,6 +203,8 @@ class Parameter:
return self.default_input / other
def __setitem__(self, index, value):
default_input = self.default_input
default_input[index] = value
return self
def set_parameter_data(self, data):

View File

@ -150,6 +150,8 @@ class Tensor(Tensor_):
return out
def __setitem__(self, index, value):
out = tensor_operator_registry.get('__setitem__')(self, index, value)
self.assign_value(out)
return self
def __gt__(self, other):

View File

@ -26,7 +26,7 @@ hyper_map = base.HyperMap()
pack = P.Pack(axis=-1)
def broadcast(broadcast_shape, x):
def _broadcast(broadcast_shape, x):
"""Broadcast tensor to the required shape."""
if F.shape(x) == broadcast_shape:
return x
@ -36,13 +36,13 @@ def broadcast(broadcast_shape, x):
return x
def transform_indexing_tensor(broadcast_shape, final_shape, new_shape, x):
def _transform_indexing_tensor(broadcast_shape, final_shape, new_shape, x):
"""Transform indexing tensor to the required."""
x = broadcast(broadcast_shape, x)
return broadcast(final_shape, F.reshape(x, new_shape))
x = _broadcast(broadcast_shape, x)
return _broadcast(final_shape, F.reshape(x, new_shape))
def generate_indices_from_tuple_of_tensor(data, tuple_index, op_name):
def _generate_indices_from_tuple_of_tensor(data, tuple_index, op_name):
"""Generate an indices tensor from a tuple of tensor."""
indices = None
check_index_tensor_number = const_utils.check_number_of_index_tensor(F.shape(data), len(tuple_index), op_name)
@ -52,26 +52,31 @@ def generate_indices_from_tuple_of_tensor(data, tuple_index, op_name):
if check_dtypes:
shape_tuple = hyper_map(F.shape, tuple_index)
broadcast_shape = const_utils.generate_broadcast_shape(shape_tuple, op_name)
broadcast_tensors = hyper_map(F.partial(broadcast, broadcast_shape), tuple_index)
broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index)
indices = pack(broadcast_tensors)
return indices
def generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
"""Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor."""
indexes_types = hyper_map(F.typeof, tuple_index)
int_positions = const_utils.get_pos_of_int_index(indexes_types)
for i in int_positions:
tuple_index = F.tuple_setitem(tuple_index, i, F.scalar_to_tensor(tuple_index[i], mstype.int32))
indexes_types = hyper_map(F.typeof, tuple_index)
tuple_index_new = ()
tuple_len = len(tuple_index)
for i in range(tuple_len):
if i in int_positions:
tuple_index_new = tuple_index_new + (F.scalar_to_tensor(tuple_index[i], mstype.int32),)
else:
tuple_index_new = tuple_index_new + (tuple_index[i],)
indexes_types = hyper_map(F.typeof, tuple_index_new)
tensor_positions, slice_positions, ellipsis_position = \
const_utils.separate_mixed_tensors_index(indexes_types, op_name)
tensor_indexes = []
slice_indexes = []
for i in tensor_positions:
tensor_indexes.append(tuple_index[i])
tensor_indexes.append(tuple_index_new[i])
for j in slice_positions:
slice_indexes.append(tuple_index[j])
slice_indexes.append(tuple_index_new[j])
data_shape = F.shape(data)
tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes)
@ -85,14 +90,14 @@ def generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
slice_number = 0
final_index_tensors = []
tuple_index_size = len(tuple_index)
tuple_index_size = len(tuple_index_new)
index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info)
for i in range(tuple_index_size):
if i in tensor_positions:
transform_tensor = transform_indexing_tensor(broadcast_shape,
transform_tensor = _transform_indexing_tensor(broadcast_shape,
final_shape,
index_tensor_new_shape,
tuple_index[i])
tuple_index_new[i])
final_index_tensors.append(transform_tensor)
if i in slice_positions:
slice_tensor = const_utils.convert_slice_to_tensor(slice_number,
@ -114,7 +119,7 @@ def generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
return indices
def generate_updates_from_scalar(data, indices, value, op_type):
def _generate_updates_from_scalar(data, indices, value, op_type):
"""Generate an updates tensor from a scalar."""
data_shape = F.shape(data)
indices_shape = F.shape(indices)
@ -122,7 +127,7 @@ def generate_updates_from_scalar(data, indices, value, op_type):
return const_utils.convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type)
def generate_updates_from_tuple(data, index, value, op_type):
def _generate_updates_from_tuple(data, index, value, op_type):
"""Generate an updates tensor from a tuple."""
value_types = hyper_map(F.typeof, value)
data_dtype = F.dtype(data)
@ -132,14 +137,14 @@ def generate_updates_from_tuple(data, index, value, op_type):
shapes_same = const_utils.check_shapes_same(value_shapes, const_utils.TENSOR_SETITEM)
if shapes_same:
value = F.pack(value)
return generate_updates_from_tensor(data, index, value, op_type)
return _generate_updates_from_tensor(data, index, value, op_type)
data_shape = F.shape(data)
index_shape = F.shape(index)
return const_utils.convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type)
def generate_updates_from_tensor(data, index, value, op_type):
def _generate_updates_from_tensor(data, index, value, op_type):
"""Generate an updates tensor from a tensor."""
data_shape = F.shape(data)
index_shape = F.shape(index)
@ -152,43 +157,45 @@ def generate_updates_from_tensor(data, index, value, op_type):
updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type)
need_broadcast = const_utils.check_two_shapes_need_broadcast(updates_shape, value_shape)
if need_broadcast:
return broadcast(updates_shape, value)
return _broadcast(updates_shape, value)
return value
def tensor_getitem(self, index):
def _tensor_getitem(self, index):
"""Handle tensor getitem"""
if isinstance(index, Tensor):
return tensor_index_by_tensor(self, index)
if isinstance(index, tuple):
return tensor_index_by_tuple(self, index)
if isinstance(index, int):
return tensor_index_by_integer(self, index)
return _tensor_index_by_integer(self, index)
if isinstance(index, slice):
return tensor_index_by_slice(self, index)
if isinstance(index, bool):
return tensor_index_by_bool(self, index)
return _tensor_index_by_bool(self, index)
if index is None:
return F.expand_dims(self, 0)
if index is ...:
return self
raise IndexError(f"Only support integers, slices(`:`), ellipsis(`...`), None, bool and tensor with int32, "
f"got {index} with type {type(index)}.")
tensor_operator_registry.register("__getitem__", tensor_getitem)
tensor_operator_registry.register("__getitem__", _tensor_getitem)
def tensor_getitem_by_tuple_of_tensor(data, tuple_index):
def _tensor_getitem_by_tuple_of_tensor(data, tuple_index):
"""Tensor getitem by a tuple of tensor."""
indices = generate_indices_from_tuple_of_tensor(data,
indices = _generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_GETITEM)
result = F.gather_nd(data, indices)
return result
def tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index):
def _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index):
"""Tensor getitem by a tuple of mixed tensor."""
indices = generate_indices_from_tuple_of_mixed_tensors(data,
indices = _generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_GETITEM)
result = F.gather_nd(data, indices)
@ -204,7 +211,7 @@ def tensor_index_by_slice(data, slice_index):
return F.strided_slice(data, begin_strides, end_strides, step_strides)
def tensor_index_by_integer(data, number):
def _tensor_index_by_integer(data, number):
"""Tensor getitem by a single integer number"""
shape = F.shape(data)
if not shape:
@ -214,7 +221,7 @@ def tensor_index_by_integer(data, number):
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)
def tensor_index_by_bool(data, bool_value):
def _tensor_index_by_bool(data, bool_value):
"""Tensor getitem by a single bool value"""
if bool_value:
return F.expand_dims(data, 0)
@ -225,9 +232,9 @@ def tensor_index_by_number(data, number):
"""Tensor getitem by a Number which may be integer/float/bool value"""
number_type = const_utils.check_number_index_type(number)
if number_type == const_utils.BOOL_:
return tensor_index_by_bool(data, number)
return _tensor_index_by_bool(data, number)
if number_type == const_utils.INT_:
return tensor_index_by_integer(data, number)
return _tensor_index_by_integer(data, number)
return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool.")
@ -241,7 +248,7 @@ def tensor_index_by_tensor(data, tensor_index):
"the index tensor data type only support mstype.int32.")
def tensor_index_by_tuple_slice(data, t):
def _tensor_index_by_tuple_slice(data, t):
"""Tensor getitem by a tuple of slice"""
shape = F.shape(data)
if len(t) > len(shape):
@ -257,7 +264,303 @@ def tensor_index_by_tuple(data, tuple_index):
indexes_types = hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM)
if index_elements_type == const_utils.NO_TENSOR:
return tensor_index_by_tuple_slice(data, tuple_index)
return _tensor_index_by_tuple_slice(data, tuple_index)
if index_elements_type == const_utils.ALL_TENSOR:
return tensor_getitem_by_tuple_of_tensor(data, tuple_index)
return tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index)
return _tensor_getitem_by_tuple_of_tensor(data, tuple_index)
return _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index)
def _tensor_setitem(self, index, value):
"""Handle tensor getitem"""
if isinstance(index, Tensor):
if isinstance(value, (int, float, bool)):
return tensor_setitem_by_tensor_with_number(self, index, value)
if isinstance(value, Tensor):
return tensor_setitem_by_tensor_with_tensor(self, index, value)
if isinstance(value, tuple):
return tensor_setitem_by_tensor_with_tuple(self, index, value)
if isinstance(index, tuple):
if isinstance(value, (int, float, bool)):
return tensor_setitem_by_tuple_with_number(self, index, value)
if isinstance(value, Tensor):
return tensor_setitem_by_tuple_with_tensor(self, index, value)
if isinstance(value, tuple):
return tensor_setitem_by_tuple_with_tuple(self, index, value)
if isinstance(index, int):
if isinstance(value, (int, float, bool)):
return tensor_setitem_by_number_with_number(self, index, value)
if isinstance(value, Tensor):
return tensor_setitem_by_number_with_tensor(self, index, value)
if isinstance(index, slice):
if isinstance(value, (int, float, bool)):
return tensor_setitem_by_slice_with_number(self, index, value)
if isinstance(value, Tensor):
return tensor_setitem_by_slice_with_tensor(self, index, value)
if isinstance(index, bool):
return _tensor_index_by_bool(self, index)
if index is ...:
if isinstance(value, (int, float, bool)):
return tensor_setitem_by_ellipsis_with_number(self, index, value)
if isinstance(value, Tensor):
return tensor_setitem_by_ellipsis_with_tensor(self, index, value)
raise IndexError("Tensor setitem index only support integers, slices(`:`), ellipsis(`...`), None, bool\
and tensor with int32, got {} with type{}".format(index, type(index)))
tensor_operator_registry.register("__setitem__", _tensor_setitem)
def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
"""Set a tensor item by a int tensor with a tensor."""
updates = _generate_updates_from_tensor(data, index, value,
const_utils.SET_ITEM_BY_ONE_TENSOR)
index = F.expand_dims(index, -1)
return P.TensorScatterUpdate()(data, index, updates)
def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
"""Set a tensor item by a bool tensor with a tensor."""
index_shape = F.shape(index)
data_shape = F.shape(data)
data_shape = const_utils.check_equal(data_shape, index_shape,
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
size = F.size(value)
size = const_utils.check_equal(1, size,
"When assign value is a tensor, its size should be {}, but current size is {}.")
dtype = F.dtype(data)
u_cast = F.cast(value, dtype)
one_data = F.ones_like(data)
u = F.tensor_mul(one_data, u_cast)
result = F.select(index, u, data)
return result
def tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
"""setitem by tensor index(dtype is int or bool) with tensor as value"""
index_dtype = F.dtype(index)
tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype)
if tensor_dtype == const_utils.INT_:
return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor)
return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor)
def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value):
"""Set a tensor item by a bool tensor with a scalar."""
index_shape = F.shape(index)
shape = F.shape(data)
shape = const_utils.check_equal(
shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
dtype = F.dtype(data)
u = F.fill(dtype, shape, value)
return F.select(index, u, data)
def _tensor_setitem_by_int_tensor_with_scalar(data, index, value):
"""Set a tensor item by a int tensor with a scalar."""
updates = _generate_updates_from_scalar(data, index, value,
const_utils.SET_ITEM_BY_ONE_TENSOR)
index = F.expand_dims(index, -1)
return P.TensorScatterUpdate()(data, index, updates)
def tensor_setitem_by_tensor_with_number(data, index, value):
index_dtype = F.dtype(index)
tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype)
if tensor_dtype == const_utils.BOOL_:
return _tensor_setitem_by_bool_tensor_with_scalar(data, index, value)
if tensor_dtype == const_utils.INT_:
return _tensor_setitem_by_int_tensor_with_scalar(data, index, value)
return const_utils.raise_index_error("For tensor setitem, indexing tensor dtype only supports bool/int")
def tensor_setitem_by_tensor_with_tuple(data, index, value):
"""Assigns the tensor by tensor with tuple value."""
index_dtype = F.dtype(index)
check_dtype = const_utils.check_index_tensor_dtype(index_dtype, const_utils.TENSOR_SETITEM)
result = None
if check_dtype:
result = _tensor_setitem_by_tensor_with_tuple(data, index, value)
return result
def _tensor_indices_number(data, data_shape, index, indices, value):
"""Assigns a scalar value to the tensor."""
data_size = F.size(data)
data_dtype = F.dtype(data)
indices_size = F.size(indices)
indices_size = const_utils.check_indices(indices_size, index)
update = F.fill(mstype.int32, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition = F.reshape(condition_1d, data_shape)
condition = F.cast(condition, mstype.bool_)
value_fill = F.fill(data_dtype, (indices_size,), value)
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
u = F.reshape(value_1d, data_shape)
return F.select(condition, u, data)
def _tensor_setitem_by_tensor_with_tuple(data, index, value):
"""Set a tensor item by a tensor with a tuple."""
updates = _generate_updates_from_tuple(data, index, value,
const_utils.SET_ITEM_BY_ONE_TENSOR)
index = F.expand_dims(index, -1)
result = P.TensorScatterUpdate()(data, index, updates)
return result
def tensor_setitem_by_slice_with_number(data, input_slice, value):
"""Givens a scalar assign to tensor by slice"""
check_result = const_utils.check_tensor_setitem_index(input_slice)
result = None
if check_result:
data_shape = F.shape(data)
indices = const_utils.slice2indices(input_slice, data_shape)
is_tuple_int = const_utils.tuple_element_is_int(input_slice)
if is_tuple_int:
indices = const_utils.integer_to_indices(input_slice, data_shape)
result = _tensor_indices_number(data, data_shape, input_slice, indices, value)
return result
def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
"""Assigns the tensor by tuple with number value."""
indexes_types = hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
if index_elements_type == const_utils.NO_TENSOR:
return tensor_setitem_by_slice_with_number(data, tuple_index, value)
if index_elements_type == const_utils.ALL_TENSOR:
indices = _generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_SETITEM)
else:
indices = _generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_SETITEM)
updates = _generate_updates_from_scalar(data,
indices,
value,
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
return P.TensorScatterUpdate()(data, indices, updates)
def _tensor_indices_tensor(data, data_shape, index, indices, value):
"""Assigns a tensor value to the tensor."""
data_size = F.size(data)
data_dtype = F.dtype(data)
indices_size = F.size(indices)
indices_size = const_utils.check_indices(indices_size, index)
update = F.fill(mstype.int32, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition = F.reshape(condition_1d, data_shape)
condition = F.cast(condition, mstype.bool_)
value_fill = None
value_size = F.size(value)
value_size = const_utils.check_indices_value_size(indices_size, value_size)
if value_size == 1:
value_fill = F.fill(data_dtype, (indices_size,), 1)
value = F.cast(value, data_dtype)
value_fill = F.tensor_mul(value_fill, value)
elif value_size > 1:
value_fill = F.reshape(value, (indices_size,))
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
u = F.reshape(value_1d, data_shape)
return F.select(condition, u, data)
def tensor_setitem_by_slice_with_tensor(data, input_slice, value):
"""Assigns a tensor value to the tensor by slice."""
result = None
check_result = const_utils.check_tensor_setitem_index(input_slice)
if check_result:
data_shape = F.shape(data)
indices = const_utils.slice2indices(input_slice, data_shape)
is_tuple_int = const_utils.tuple_element_is_int(input_slice)
if is_tuple_int:
indices = const_utils.integer_to_indices(input_slice, data_shape)
result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value)
return result
def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
"""Assigns the tensor by tuple with tensor value."""
indexes_types = hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
if index_elements_type == const_utils.NO_TENSOR:
return tensor_setitem_by_slice_with_tensor(data, tuple_index, value)
if index_elements_type == const_utils.ALL_TENSOR:
indices = _generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_SETITEM)
else:
indices = _generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_SETITEM)
updates = _generate_updates_from_tensor(data,
indices,
value,
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
return P.TensorScatterUpdate()(data, indices, updates)
def tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
"""Assigns the tensor by tuple with tuple of value."""
indexes_types = hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
if index_elements_type == const_utils.ALL_TENSOR:
indices = _generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_SETITEM)
else:
indices = _generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_SETITEM)
updates = _generate_updates_from_tuple(data,
indices,
value,
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
return P.TensorScatterUpdate()(data, indices, updates)
def tensor_setitem_by_number_with_number(data, index, value):
"""Assigns the tensor by number with number value."""
data_shape = F.shape(data)
indices = const_utils.integer_to_indices(index, data_shape)
return _tensor_indices_number(data, data_shape, index, indices, value)
def tensor_setitem_by_number_with_tensor(data, index, value):
"""Assigns the tensor by number with tensor value."""
data_shape = F.shape(data)
indices = const_utils.integer_to_indices(index, data_shape)
return _tensor_indices_tensor(data, data_shape, index, indices, value)
def tensor_setitem_by_ellipsis_with_number(data, index, value):
"""Assigns the tensor by ellipsis with number value."""
data_shape = F.shape(data)
data_dtype = F.dtype(data)
return F.fill(data_dtype, data_shape, value)
def tensor_setitem_by_ellipsis_with_tensor(data, index, value):
"""Assigns the tensor by ellipsis with tensor value."""
result = None
data_shape = F.shape(data)
data_dtype = F.dtype(data)
data_size = F.size(data)
value_shape = F.shape(value)
value_size = F.size(value)
check_result = const_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size)
if check_result:
if data_size == value_size:
result = F.reshape(value, data_shape)
result = F.cast(result, data_dtype)
elif value_size == 1:
param1 = F.fill(data_dtype, data_shape, 1)
param2 = F.cast(value, data_dtype)
result = F.tensor_mul(param1, param2)
return result

View File

@ -16,10 +16,8 @@
"""Implementation for setitem."""
from . import _compile_utils as compile_utils
from . import _constexpr_utils as const_utils
from ... import functional as F
from ...composite import base
from ....common import dtype as mstype
setitem = base.MultitypeFuncGraph('setitem')
@ -139,11 +137,7 @@ def _tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
Outputs:
Tensor, element type and shape is same as data.
"""
index_dtype = F.dtype(index)
tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype)
if tensor_dtype == const_utils.INT_:
return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor)
return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor)
return compile_utils.tensor_setitem_by_tensor_with_tensor(data, index, value_tensor)
@setitem.register("Tensor", "Tensor", "Number")
@ -166,11 +160,7 @@ def _tensor_setitem_by_tensor_with_number(data, index, value):
Outputs:
Tensor, element type and shape is same as data.
"""
index_dtype = F.dtype(index)
tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype)
if tensor_dtype == const_utils.BOOL_:
return _tensor_setitem_by_bool_tensor_with_scalar(data, index, value)
return _tensor_setitem_by_int_tensor_with_scalar(data, index, value)
return compile_utils.tensor_setitem_by_tensor_with_number(data, index, value)
@setitem.register("Tensor", "Tuple", "Number")
@ -191,24 +181,7 @@ def _tensor_setitem_by_tuple_with_number(data, tuple_index, value):
Outputs:
Tensor, element type and shape is same as data.
"""
indexes_types = compile_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
if index_elements_type == const_utils.NO_TENSOR:
return _tensor_assgin_number(data, tuple_index, value)
if index_elements_type == const_utils.ALL_TENSOR:
indices = compile_utils.generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_SETITEM)
else:
indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_SETITEM)
updates = compile_utils.generate_updates_from_scalar(data,
indices,
value,
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
return F.scatter_nd_update(data, indices, updates)
return compile_utils.tensor_setitem_by_tuple_with_number(data, tuple_index, value)
@setitem.register("Tensor", "Tuple", "Tensor")
@ -229,24 +202,7 @@ def _tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
Outputs:
Tensor, element type and shape is same as data.
"""
indexes_types = compile_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
if index_elements_type == const_utils.NO_TENSOR:
return _tensor_assgin_tensor(data, tuple_index, value)
if index_elements_type == const_utils.ALL_TENSOR:
indices = compile_utils.generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_SETITEM)
else:
indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_SETITEM)
updates = compile_utils.generate_updates_from_tensor(data,
indices,
value,
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
return F.scatter_nd_update(data, indices, updates)
return compile_utils.tensor_setitem_by_tuple_with_tensor(data, tuple_index, value)
@setitem.register("Tensor", "Tuple", "Tuple")
@ -268,22 +224,7 @@ def _tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
Outputs:
Tensor, element type and shape is same as data.
"""
indexes_types = compile_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
if index_elements_type == const_utils.ALL_TENSOR:
indices = compile_utils.generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_SETITEM)
else:
indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_SETITEM)
updates = compile_utils.generate_updates_from_tuple(data,
indices,
value,
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
return F.scatter_nd_update(data, indices, updates)
return compile_utils.tensor_setitem_by_tuple_with_tuple(data, tuple_index, value)
@setitem.register("Tensor", "Tensor", "Tuple")
@ -299,12 +240,7 @@ def _tensor_setitem_by_tensor_v2(data, index, value):
Outputs:
Tensor, element type and shape is same as data.
"""
index_dtype = F.dtype(index)
check_dtype = const_utils.check_index_tensor_dtype(index_dtype, const_utils.TENSOR_SETITEM)
result = None
if check_dtype:
result = _tensor_setitem_by_tensor_with_tuple(data, index, value)
return result
return compile_utils.tensor_setitem_by_tensor_with_tuple(data, index, value)
@setitem.register("Tensor", "Slice", "Tensor")
@ -326,7 +262,7 @@ def _tensor_setitem_with_slice_v3(data, input_slice, value):
Outputs:
Tensor, element type and shape is same as data.
"""
return _tensor_assgin_tensor(data, input_slice, value)
return compile_utils.tensor_setitem_by_slice_with_tensor(data, input_slice, value)
@setitem.register("Tensor", "Slice", "Number")
@ -348,168 +284,28 @@ def _tensor_setitem_with_slice_v1(data, input_slice, value):
Outputs:
Tensor, element type and shape is same as data.
"""
return _tensor_assgin_number(data, input_slice, value)
def _tensor_assgin_number(data, input_slice, value):
"""Givens a scalar assign to tensor by slice"""
check_result = const_utils.check_tensor_setitem_index(input_slice)
result = None
if check_result:
data_shape = F.shape(data)
indices = const_utils.slice2indices(input_slice, data_shape)
is_tuple_int = const_utils.tuple_element_is_int(input_slice)
if is_tuple_int:
indices = const_utils.integer_to_indices(input_slice, data_shape)
result = _tensor_indices_number(data, data_shape, input_slice, indices, value)
return result
return compile_utils.tensor_setitem_by_slice_with_number(data, input_slice, value)
@setitem.register("Tensor", "Number", "Number")
def _tensor_setitem_with_int_v1(data, index, value):
"""Syntax: A[1] = 3"""
data_shape = F.shape(data)
indices = const_utils.integer_to_indices(index, data_shape)
return _tensor_indices_number(data, data_shape, index, indices, value)
return compile_utils.tensor_setitem_by_number_with_number(data, index, value)
@setitem.register("Tensor", "Number", "Tensor")
def _tensor_setitem_with_int_v2(data, index, value):
"""Syntax: A[1] = Tensor"""
data_shape = F.shape(data)
indices = const_utils.integer_to_indices(index, data_shape)
return _tensor_indices_tensor(data, data_shape, index, indices, value)
return compile_utils.tensor_setitem_by_number_with_tensor(data, index, value)
@setitem.register("Tensor", "Ellipsis", "Number")
def _tensor_setitem_with_ellipsis_v1(data, index, value):
"""Syntax: A[...] = number."""
data_shape = F.shape(data)
data_dtype = F.dtype(data)
return F.fill(data_dtype, data_shape, value)
return compile_utils.tensor_setitem_by_ellipsis_with_number(data, index, value)
@setitem.register("Tensor", "Ellipsis", "Tensor")
def _tensor_setitem_with_ellipsis_v2(data, index, value):
"""Syntax: A[...] = Tensor."""
result = None
data_shape = F.shape(data)
data_dtype = F.dtype(data)
data_size = F.size(data)
value_shape = F.shape(value)
value_size = F.size(value)
check_result = const_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size)
if check_result:
if data_size == value_size:
result = F.reshape(value, data_shape)
result = F.cast(result, data_dtype)
elif value_size == 1:
param1 = F.fill(data_dtype, data_shape, 1)
param2 = F.cast(value, data_dtype)
result = F.tensor_mul(param1, param2)
return result
def _tensor_assgin_tensor(data, input_slice, value):
"""Assigns a tensor value to the tensor by slice."""
result = None
check_result = const_utils.check_tensor_setitem_index(input_slice)
if check_result:
data_shape = F.shape(data)
indices = const_utils.slice2indices(input_slice, data_shape)
is_tuple_int = const_utils.tuple_element_is_int(input_slice)
if is_tuple_int:
indices = const_utils.integer_to_indices(input_slice, data_shape)
result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value)
return result
def _tensor_indices_tensor(data, data_shape, index, indices, value):
"""Assigns a tensor value to the tensor."""
data_size = F.size(data)
data_dtype = F.dtype(data)
indices_size = F.size(indices)
indices_size = const_utils.check_indices(indices_size, index)
update = F.fill(mstype.int32, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition = F.reshape(condition_1d, data_shape)
condition = F.cast(condition, mstype.bool_)
value_fill = None
value_size = F.size(value)
value_size = const_utils.check_indices_value_size(indices_size, value_size)
if value_size == 1:
value_fill = F.fill(data_dtype, (indices_size,), 1)
value = F.cast(value, data_dtype)
value_fill = F.tensor_mul(value_fill, value)
elif value_size > 1:
value_fill = F.reshape(value, (indices_size,))
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
u = F.reshape(value_1d, data_shape)
return F.select(condition, u, data)
def _tensor_indices_number(data, data_shape, index, indices, value):
"""Assigns a scalar value to the tensor."""
data_size = F.size(data)
data_dtype = F.dtype(data)
indices_size = F.size(indices)
indices_size = const_utils.check_indices(indices_size, index)
update = F.fill(mstype.int32, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition = F.reshape(condition_1d, data_shape)
condition = F.cast(condition, mstype.bool_)
value_fill = F.fill(data_dtype, (indices_size,), value)
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
u = F.reshape(value_1d, data_shape)
return F.select(condition, u, data)
def _tensor_setitem_by_tensor_with_tuple(data, index, value):
"""Set a tensor item by a tensor with a tuple."""
updates = compile_utils.generate_updates_from_tuple(data, index, value,
const_utils.SET_ITEM_BY_ONE_TENSOR)
result = F.scatter_update(data, index, updates)
return result
def _tensor_setitem_by_int_tensor_with_scalar(data, index, value):
"""Set a tensor item by a int tensor with a scalar."""
updates = compile_utils.generate_updates_from_scalar(data, index, value,
const_utils.SET_ITEM_BY_ONE_TENSOR)
return F.scatter_update(data, index, updates)
def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value):
"""Set a tensor item by a bool tensor with a scalar."""
index_shape = F.shape(index)
shape = F.shape(data)
shape = const_utils.check_equal(
shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
dtype = F.dtype(data)
u = F.fill(dtype, shape, value)
return F.select(index, u, data)
def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
"""Set a tensor item by a int tensor with a tensor."""
updates = compile_utils.generate_updates_from_tensor(data, index, value,
const_utils.SET_ITEM_BY_ONE_TENSOR)
return F.scatter_update(data, index, updates)
def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
"""Set a tensor item by a bool tensor with a tensor."""
index_shape = F.shape(index)
data_shape = F.shape(data)
data_shape = const_utils.check_equal(data_shape, index_shape,
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
size = F.size(value)
size = const_utils.check_equal(1, size,
"When assign value is a tensor, its size should be {}, but current size is {}.")
dtype = F.dtype(data)
u_cast = F.cast(value, dtype)
one_data = F.ones_like(data)
u = F.tensor_mul(one_data, u_cast)
result = F.select(index, u, data)
return result
return compile_utils.tensor_setitem_by_ellipsis_with_tensor(data, index, value)

View File

@ -20,10 +20,14 @@ from mindspore import Tensor, Parameter
from mindspore import context
from mindspore import dtype as mstype
from mindspore.nn import Cell
from mindspore.common.parameter import ParameterTuple
from mindspore.ops import composite as C
def setup_module():
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
class NetWorkSlicePositive(Cell):
def __init__(self):
super(NetWorkSlicePositive, self).__init__()
@ -139,7 +143,7 @@ class TensorGetItemByThreeTensors(Cell):
return ret0, ret1, ret2
def Xtest_getitem_by_tensors():
def test_getitem_by_tensors():
net = TensorGetItemByThreeTensors()
input_x = np.arange(6*8*10).reshape(6, 8, 10).astype(np.int32)
index_0 = np.random.randint(6, size=(3, 4, 5)).astype(np.int32)
@ -155,119 +159,140 @@ def Xtest_getitem_by_tensors():
assert np.all(output2.asnumpy() == input_x[index_0, index_1, index_2] + np.ones([5, 3, 4, 5]))
class TensorGetItemByMixedTensors_0(Cell):
def __init__(self):
super(TensorGetItemByMixedTensors_0, self).__init__()
self.const = Tensor(np.ones((3, 4, 5, 3, 6, 5), np.float32))
class TensorGetItemByMixedTensorsBasicCase(Cell):
def __init__(self, c0, c1, c2, c3, c4, c5):
super(TensorGetItemByMixedTensorsBasicCase, self).__init__()
self.const0 = Tensor(c0)
self.const1 = Tensor(c1)
self.const2 = Tensor(c2)
self.const3 = Tensor(c3)
self.const4 = Tensor(c4)
self.const5 = Tensor(c5)
def construct(self, tensor, index_0, index_1):
ret = tensor[index_0, index_1, 0:3, ..., 0:5, 3] + self.const
return ret
ret0 = tensor[index_0, index_1, 0:3] + self.const0
ret1 = tensor[0:3, index_0, ...] + self.const1
ret2 = tensor[0, index_0, index_1] + self.const2
ret3 = tensor[..., index_0, 0:3] + self.const3
ret4 = tensor[0:2, index_0, index_1] + self.const4
ret5 = tensor[..., index_0, index_1] + self.const5
return ret0, ret1, ret2, ret3, ret4, ret5
class TensorGetItemByMixedTensors_1(Cell):
def __init__(self):
super(TensorGetItemByMixedTensors_1, self).__init__()
self.const = Tensor(np.ones((3, 4, 5, 3, 5, 5), np.float32))
def construct(self, tensor, index_0, index_1):
ret = tensor[0:3, index_0, ..., index_1, 3, 0:5] + self.const
return ret
class TensorGetItemByMixedTensors_2(Cell):
def __init__(self):
super(TensorGetItemByMixedTensors_2, self).__init__()
self.const = Tensor(np.ones((3, 4, 5, 6, 7), np.float32))
def construct(self, tensor, index_0, index_1):
ret = tensor[0, index_0, index_1, ..., 3] + self.const
return ret
class TensorGetItemByMixedTensors_3(Cell):
def __init__(self):
super(TensorGetItemByMixedTensors_3, self).__init__()
self.const = Tensor(np.ones((3, 4, 5, 3, 4, 3, 5), np.float32))
def construct(self, tensor, index_0, index_1):
ret = tensor[..., index_0, 0:3, index_1, 0:5] + self.const
return ret
class TensorGetItemByMixedTensors_4(Cell):
def __init__(self):
super(TensorGetItemByMixedTensors_4, self).__init__()
self.const = Tensor(np.ones((2, 2, 3, 4, 5, 3, 9), np.float32))
def construct(self, tensor, index_0, index_1, index_2):
ret = tensor[0:2, index_0, index_1, 2, index_2, 0:3, ...] + self.const
return ret
class TensorGetItemByMixedTensors_5(Cell):
def __init__(self):
super(TensorGetItemByMixedTensors_5, self).__init__()
self.const = Tensor(np.ones((2, 3, 4, 5, 2, 6), np.float32))
def construct(self, tensor, index_0, index_1, index_2):
ret = tensor[0:2, index_0, index_1, ..., index_2, 2] + self.const
return ret
class TensorGetItemByMixedTensors_6(Cell):
def __init__(self):
super(TensorGetItemByMixedTensors_6, self).__init__()
self.const = Tensor(np.ones((3, 4, 2, 3, 4, 5), np.float32))
def construct(self, tensor, index_0, index_1, index_2):
ret = tensor[..., index_0, index_1, index_2, 3] + self.const
return ret
def test_getitem_by_mixed_tensors():
const0 = np.ones((3, 4, 5, 3), np.float32)
const1 = np.ones((3, 3, 4, 5, 5), np.float32)
const2 = np.ones((3, 4, 5), np.float32)
const3 = np.ones((3, 3, 4, 5, 3), np.float32)
const4 = np.ones((2, 3, 4, 5), np.float32)
const5 = np.ones((3, 3, 4, 5), np.float32)
net = TensorGetItemByMixedTensorsBasicCase(const0, const1, const2, const3, const4, const5)
input_np = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32)
input_ms = Tensor(input_np, mstype.float32)
index_np_0 = np.random.randint(3, size=(3, 4, 5)).astype(np.int32)
index_np_1 = np.random.randint(4, size=(4, 5)).astype(np.int32)
index_0 = Tensor(index_np_0, mstype.int32)
index_1 = Tensor(index_np_1, mstype.int32)
out0, out1, out2, out3, out4, out5 = net(input_ms, index_0, index_1)
assert np.all(out0.asnumpy() == (input_np[index_np_0, index_np_1, 0:3] + const0))
assert np.all(out1.asnumpy() == (input_np[0:3, index_np_0, ...] + const1))
assert np.all(out2.asnumpy() == (input_np[0, index_np_0, index_np_1] + const2))
assert np.all(out3.asnumpy() == (input_np[..., index_np_0, 0:3] + const3))
assert np.all(out4.asnumpy() == (input_np[0:2, index_np_0, index_np_1] + const4))
assert np.all(out5.asnumpy() == (input_np[..., index_np_0, index_np_1] + const5))
class TensorSetItemByMixedTensors_0(Cell):
def __init__(self, value):
super(TensorSetItemByMixedTensors_0, self).__init__()
self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8, 9), np.float32))
self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)),
self.const = Tensor(np.ones((3, 4, 5), np.float32))
self.param = Parameter(Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)),
mstype.float32),
name="x")
self.value = value
def construct(self, index_0, index_1, index_2):
self.param[0:2, index_0, index_1, 2, index_2, 0:3, ...] = self.value
self.param[0:2, index_0, index_1] = self.value
ret = self.param + self.const
return ret
def test_setitem_by_mixed_tensors_0():
value = 88.0
net = TensorSetItemByMixedTensors_0(value)
index_0 = np.random.randint(3, size=(3, 4, 5))
index_1 = np.random.randint(4, size=(4, 5))
index_2 = np.random.randint(3, size=(2, 1, 4, 5))
index_0_ms = Tensor(index_0, mstype.int32)
index_1_ms = Tensor(index_1, mstype.int32)
index_2_ms = Tensor(index_2, mstype.int32)
input_np = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32)
const = np.ones((3, 4, 5), np.float32)
out = net(index_0_ms, index_1_ms, index_2_ms)
input_np[0:2, index_0, index_1] = value
assert np.all(out.asnumpy() == (input_np + const))
class TensorSetItemByMixedTensors_1(Cell):
def __init__(self, value):
super(TensorSetItemByMixedTensors_1, self).__init__()
self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float32))
self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
self.const = Tensor(np.ones((3, 4, 5), np.float32))
self.param = Parameter(Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32),
name="x")
self.value = value
def construct(self, index_0, index_1, index_2):
self.param[0:2, index_0, index_1, ..., index_2, 2] = self.value
self.param[0:2, index_0, ...] = self.value
ret = self.param + self.const
return ret
def test_setitem_by_mixed_tensors_1():
value = 88.0
net = TensorSetItemByMixedTensors_1(value)
index_0 = np.random.randint(3, size=(3, 4, 5))
index_1 = np.random.randint(4, size=(4, 5))
index_2 = np.random.randint(3, size=(2, 1, 4, 5))
index_0_ms = Tensor(index_0, mstype.int32)
index_1_ms = Tensor(index_1, mstype.int32)
index_2_ms = Tensor(index_2, mstype.int32)
input_np = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32)
const = np.ones((3, 4, 5), np.float32)
out = net(index_0_ms, index_1_ms, index_2_ms)
input_np[0:2, index_0, ...] = value
assert np.all(out.asnumpy() == (input_np + const))
class TensorSetItemByMixedTensors_2(Cell):
def __init__(self, value):
super(TensorSetItemByMixedTensors_2, self).__init__()
self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float16))
self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float16),
self.const = Tensor(np.ones((3, 4, 5), np.float16))
self.param = Parameter(Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float16),
name="x")
self.value = value
def construct(self, index_0, index_1, index_2):
self.param[..., index_0, index_1, index_2, 3] = self.value
self.param[..., index_0, 1] = self.value
ret = self.param + self.const
return ret
def test_setitem_by_mixed_tensors_2():
value = 88.0
net = TensorSetItemByMixedTensors_2(value)
index_0 = np.random.randint(3, size=(3, 4, 5))
index_1 = np.random.randint(4, size=(4, 5))
index_2 = np.random.randint(3, size=(2, 1, 4, 5))
index_0_ms = Tensor(index_0, mstype.int32)
index_1_ms = Tensor(index_1, mstype.int32)
index_2_ms = Tensor(index_2, mstype.int32)
input_np = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32)
const = np.ones((3, 4, 5), np.float32)
out = net(index_0_ms, index_1_ms, index_2_ms)
input_np[..., index_0, 1] = value
assert np.all(out.asnumpy() == (input_np + const))
class TensorGetItemByMixedTensorsTypeError(Cell):
def __init__(self):
super(TensorGetItemByMixedTensorsTypeError, self).__init__()
@ -277,13 +302,13 @@ class TensorGetItemByMixedTensorsTypeError(Cell):
return ret
class TensorGetItemByMixedTensorsNumberError(Cell):
def __init__(self):
super(TensorGetItemByMixedTensorsNumberError, self).__init__()
def construct(self, x, index_0, index_1):
ret = x[index_0, index_1, 0:3, ..., index_1, index_0]
return ret
def test_getitem_by_mixedtensor_exception():
input_ms = Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32)
index_0 = Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32)
index_1 = Tensor(np.random.randint(4, size=(3, 4, 5)), mstype.int32)
net1 = TensorGetItemByMixedTensorsTypeError()
with pytest.raises(TypeError):
net1(input_ms, index_0, index_1)
class TensorSetItemByOneTensorWithNumber(Cell):
@ -299,6 +324,18 @@ class TensorSetItemByOneTensorWithNumber(Cell):
return ret
def test_setitem_one_tensor_with_number():
value = 0.0
net = TensorSetItemByOneTensorWithNumber(value)
index_np = np.random.randint(4, size=(5, 4))
index = Tensor(index_np, mstype.int32)
input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8))
const = np.ones((6, 7, 8)).astype(np.float32)
out = net(index)
input_data[index_np] = value
assert np.all(out.asnumpy() == (input_data + const))
class TensorSetItemByOneTensorWithTensor(Cell):
def __init__(self):
super(TensorSetItemByOneTensorWithTensor, self).__init__()
@ -311,6 +348,19 @@ class TensorSetItemByOneTensorWithTensor(Cell):
return ret
def test_setitem_by_one_tensor_with_tensor():
net = TensorSetItemByOneTensorWithTensor()
index_np = np.random.randint(4, size=(5, 4))
index = Tensor(index_np, mstype.int32)
input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8))
const = np.ones((6, 7, 8)).astype(np.float32)
value = np.zeros((4, 7, 8)).astype(np.float32)
value_ms = Tensor(value, mstype.float32)
out = net(index, value_ms)
input_data[index_np] = value
assert np.all(out.asnumpy() == (input_data + const))
class TensorSetItemByOneTensorWithTupleOfNumber(Cell):
def __init__(self, value):
super(TensorSetItemByOneTensorWithTupleOfNumber, self).__init__()
@ -324,6 +374,18 @@ class TensorSetItemByOneTensorWithTupleOfNumber(Cell):
return ret
def test_setitem_by_one_tensor_with_tuple_number():
value = (0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7)
net = TensorSetItemByOneTensorWithTupleOfNumber(value)
input_np = np.random.randint(5, size=(5, 4))
input_ms = Tensor(input_np, mstype.int32)
input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32)
const = np.ones((6, 7, 8)).astype(np.float32)
out = net(input_ms)
input_data[input_np] = value
assert np.all(out.asnumpy() == (input_data + const))
class TensorSetItemByOneTensorWithTupleOfTensor(Cell):
def __init__(self):
super(TensorSetItemByOneTensorWithTupleOfTensor, self).__init__()
@ -336,6 +398,23 @@ class TensorSetItemByOneTensorWithTupleOfTensor(Cell):
return ret
def test_setitem_by_one_tensor_with_tuple_tensors():
net = TensorSetItemByOneTensorWithTupleOfTensor()
input_np = np.random.randint(6, size=(5, 4)).astype(np.int32)
input_ms = Tensor(input_np, mstype.int32)
input_data = np.arange(6 * 3 * 8).reshape((6, 3, 8)).astype(np.float32)
value_0_np = np.zeros((8,), np.float32)
value_1_np = np.ones((8,), np.float32)
value_2_np = np.ones((8,), np.float32)*2
value_0 = Tensor(value_0_np)
value_1 = Tensor(value_1_np)
value_2 = Tensor(value_2_np)
const = np.ones((6, 3, 8)).astype(np.float32)
out = net(input_ms, value_0, value_1, value_2)
input_data[input_np] = (value_0_np, value_1_np, value_2_np)
assert np.all(out.asnumpy() == (input_data + const))
class TensorSetItemByTensorsWithNumber(Cell):
def __init__(self, value):
super(TensorSetItemByTensorsWithNumber, self).__init__()
@ -349,6 +428,22 @@ class TensorSetItemByTensorsWithNumber(Cell):
return ret
def test_setitem_by_tensors_with_number():
value = 0.0
net = TensorSetItemByTensorsWithNumber(value)
index_0 = np.random.randint(6, size=(3, 4, 5))
index_1 = np.random.randint(7, size=(4, 5))
index_2 = np.random.randint(8, size=(5, 3, 4, 5))
index_0_ms = Tensor(index_0, mstype.int32)
index_1_ms = Tensor(index_1, mstype.int32)
index_2_ms = Tensor(index_2, mstype.int32)
out = net(index_0_ms, index_1_ms, index_2_ms)
const = np.ones((6, 7, 8)).astype(np.float32)
input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32)
input_data[index_0, index_1, index_2] = value
assert np.all(out.asnumpy() == (input_data + const))
class TensorSetItemByTensorsWithTensor(Cell):
def __init__(self):
super(TensorSetItemByTensorsWithTensor, self).__init__()
@ -361,6 +456,23 @@ class TensorSetItemByTensorsWithTensor(Cell):
return ret
def test_setitem_by_tensors_with_tensor():
net = TensorSetItemByTensorsWithTensor()
index_0 = np.random.randint(6, size=(3, 4, 5))
index_1 = np.random.randint(7, size=(4, 5))
index_2 = np.random.randint(8, size=(5, 3, 4, 5))
value = np.zeros((4, 5)).astype(np.float32)
index_0_ms = Tensor(index_0, mstype.int32)
index_1_ms = Tensor(index_1, mstype.int32)
index_2_ms = Tensor(index_2, mstype.int32)
value_ms = Tensor(value, mstype.float32)
out = net(index_0_ms, index_1_ms, index_2_ms, value_ms)
const = np.ones((6, 7, 8)).astype(np.float32)
input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32)
input_data[index_0, index_1, index_2] = value
assert np.all(out.asnumpy() == (input_data + const))
class TensorSetItemByTensorsWithTensorNumberError(Cell):
def __init__(self):
super(TensorSetItemByTensorsWithTensorNumberError, self).__init__()
@ -373,6 +485,17 @@ class TensorSetItemByTensorsWithTensorNumberError(Cell):
return ret
def test_setitem_by_tensors_with_tensor_error():
index_0 = Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32)
index_1 = Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)
index_2 = Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)
index_3 = Tensor(np.random.randint(8, size=(1, 3, 4, 5)), mstype.int32)
value = Tensor(np.zeros((2, 5)), mstype.float32)
net = TensorSetItemByTensorsWithTensorNumberError()
with pytest.raises(IndexError):
net(index_0, index_1, index_2, index_3, value)
class TensorSetItemByTensorsWithTupleOfNumber(Cell):
def __init__(self, value):
super(TensorSetItemByTensorsWithTupleOfNumber, self).__init__()
@ -386,6 +509,22 @@ class TensorSetItemByTensorsWithTupleOfNumber(Cell):
return ret
def test_setitem_by_tensors_with_tuple_of_number():
value = (0.0, 1.1, 2.2, 3.3, 4.4)
net = TensorSetItemByTensorsWithTupleOfNumber(value)
index_0 = np.random.randint(6, size=(3, 4, 5))
index_1 = np.random.randint(7, size=(4, 5))
index_2 = np.random.randint(8, size=(5, 3, 4, 5))
index_0_ms = Tensor(index_0, mstype.int32)
index_1_ms = Tensor(index_1, mstype.int32)
index_2_ms = Tensor(index_2, mstype.int32)
input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32)
input_data[index_0, index_1, index_2] = value
const = np.ones((6, 7, 8)).astype(np.float32)
out = net(index_0_ms, index_1_ms, index_2_ms)
assert np.all(out.asnumpy() == (input_data + const))
class TensorSetItemByTensorsWithTupleOfTensor(Cell):
def __init__(self):
super(TensorSetItemByTensorsWithTupleOfTensor, self).__init__()
@ -398,6 +537,27 @@ class TensorSetItemByTensorsWithTupleOfTensor(Cell):
return ret
def test_setitem_by_tensors_with_tuple_of_tensor():
value_0 = np.zeros((4, 5))
value_1 = np.ones((4, 5))
value_2 = np.ones((4, 5)) * 2
value_0_ms = Tensor(value_0, mstype.float32)
value_1_ms = Tensor(value_1, mstype.float32)
value_2_ms = Tensor(value_2, mstype.float32)
net = TensorSetItemByTensorsWithTupleOfTensor()
index_0 = np.random.randint(6, size=(3, 4, 5))
index_1 = np.random.randint(7, size=(4, 5))
index_2 = np.random.randint(8, size=(5, 3, 4, 5))
index_0_ms = Tensor(index_0, mstype.int32)
index_1_ms = Tensor(index_1, mstype.int32)
index_2_ms = Tensor(index_2, mstype.int32)
input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32)
input_data[index_0, index_1, index_2] = (value_0, value_1, value_2)
const = np.ones((6, 7, 8)).astype(np.float32)
out = net(index_0_ms, index_1_ms, index_2_ms, value_0_ms, value_1_ms, value_2_ms)
assert np.all(out.asnumpy() == (input_data + const))
class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell):
def __init__(self):
super(TensorSetItemByTensorsWithTupleOfTensorNumberError, self).__init__()
@ -410,17 +570,44 @@ class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell):
return ret
class TensorSetItemByMixedTensors(Cell):
def __init__(self):
super(TensorSetItemByMixedTensors, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
self.value = 99.0
def test_setitem_by_tensor_with_tuple_of_tensor_error():
net = TensorSetItemByTensorsWithTupleOfTensorNumberError()
index_0_ms = Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32)
index_1_ms = Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)
index_2_ms = Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)
value_0 = np.zeros((4, 5))
value_1 = np.ones((4, 5))
value_0_ms = Tensor(value_0, mstype.float32)
value_1_ms = Tensor(value_1, mstype.float32)
with pytest.raises(ValueError):
net(index_0_ms, index_1_ms, index_2_ms, value_0_ms, value_1_ms)
def construct(self, index_0, index_1):
self.param[index_0, index_1, 0:6] = self.value
ret = self.param + self.const
return ret
def test_setitem_grad():
class Net(Cell):
def __init__(self):
super(Net, self).__init__()
self.weight = Parameter(
Tensor(np.ones([4, 4, 5]), dtype=mstype.float32), "b1", requires_grad=True)
def construct(self, a, b):
a[1:3:1, ::] = b
c = a + self.weight
return c
class GradNet(Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.weights = ParameterTuple(net.trainable_params())
def construct(self, x, y, sens):
return C.grad_by_list_with_sens(self.net, self.weights)(x, y, sens)
net = GradNet(Net())
x = Tensor(np.ones([4, 4, 5]).astype(np.float32), mstype.float32)
y = Tensor(np.array([3]).astype(np.float32), mstype.float32)
sens = Tensor(np.ones([4, 4, 5]).astype(np.float32), mstype.float32)
net(x, y, sens)
class TensorAssignWithSliceError1(Cell):
@ -475,7 +662,6 @@ class TensorAssignWithSlice(Cell):
def test_tensor_assign():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
net = TensorAssignWithSlice()
net2 = TensorAssignWithSlice2()
net_e1 = TensorAssignWithSliceError1()
@ -621,7 +807,7 @@ class TensorAssignWithTupleInteger(Cell):
class TensorAssignWithBoolTensorIndex(Cell):
def __init__(self):
super(TensorAssignWithBoolTensorIndex, self).__init__()
self.t = Tensor(np.arange(60).reshape([3, 4, 5]), dtype=mstype.float32)
self.t = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
self.u_scalar = 5
def construct(self, a, b, c, u_tensor):
@ -643,8 +829,7 @@ class TensorAssignWithBoolTensorIndexError(Cell):
class TensorAssignWithBoolTensorIndex2(Cell):
def __init__(self):
super(TensorAssignWithBoolTensorIndex2, self).__init__()
self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float32)
self.t = Tensor(np.arange(60).reshape([3, 4, 5]), dtype=mstype.float32)
self.t = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
self.u_scalar = 5
def construct(self, a, u_tensor):
@ -666,7 +851,40 @@ class TensorAssignWithBoolTensorIndex2Error(Cell):
return a
def Xtest_tensor_assign_bool_index():
def test_tensor_assign_bool_index_0():
a = np.arange(60).reshape(3, 4, 5)
b = a > 5
c = a < 3
Ta = Tensor(a, dtype=mstype.float32)
Tb = Tensor(b)
Tc = Tensor(c)
u_tensor = Tensor([1], dtype=mstype.float32)
net1 = TensorAssignWithBoolTensorIndex()
out = net1(Ta, Tb, Tc, u_tensor)
res = np.arange(60).reshape(3, 4, 5)
res[c] = 5
res[b] = 1
res = res + np.ones([3, 4, 5])
assert np.all(out.asnumpy() == res)
def test_tensor_assign_bool_index_1():
a = np.arange(60).reshape(3, 4, 5)
Ta = Tensor(a, dtype=mstype.float32)
u_tensor = Tensor([1], dtype=mstype.float32)
net2 = TensorAssignWithBoolTensorIndex2()
out = net2(Ta, u_tensor)
res = np.arange(60).reshape(3, 4, 5)
res[res > 8] = 1
res[res >= 6] = 5
res[res < 3] = 5
res[res <= 5] = 1
res[res == 5] = 5
res = res + np.ones([3, 4, 5])
assert np.all(out.asnumpy() == res)
def test_tensor_assign_bool_index_exception():
a = np.arange(60).reshape(3, 4, 5)
b = a > 5
c = a < 3
@ -679,8 +897,6 @@ def Xtest_tensor_assign_bool_index():
u_scalar = 5
net1 = TensorAssignWithBoolTensorIndex()
net2 = TensorAssignWithBoolTensorIndex2()
net1(Ta, Tb, Tc, u_tensor)
net1(Ta, Tb, Tc, u_tensor)
with pytest.raises(ValueError):
net1(Ta, Td, Tc, u_tensor)
with pytest.raises(IndexError):
@ -695,14 +911,14 @@ def Xtest_tensor_assign_bool_index():
with pytest.raises(ValueError):
net2(Ta, u_tensor_error)
net3 = TensorAssignWithBoolTensorIndexError()
with pytest.raises(AttributeError):
with pytest.raises(IndexError):
net3(Ta, Tb, Tc, u_tensor)
with pytest.raises(AttributeError):
with pytest.raises(IndexError):
net3(Ta, Tb, Tc, u_scalar)
net4 = TensorAssignWithBoolTensorIndex2Error()
with pytest.raises(AttributeError):
with pytest.raises(IndexError):
net4(Ta, u_tensor)
with pytest.raises(AttributeError):
with pytest.raises(IndexError):
net4(Ta, u_scalar)