forked from OSSInnovation/mindspore
add support for parameter
support for tensor setitem add support for tensor assgin
This commit is contained in:
parent
c55b81e94f
commit
79058d3509
|
@ -92,6 +92,10 @@ Tensor &Tensor::operator=(const Tensor &tensor) {
|
|||
}
|
||||
return *this;
|
||||
}
|
||||
Tensor &Tensor::AssignValue(const Tensor &tensor) {
|
||||
*this = tensor;
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool Tensor::operator==(const Tensor &tensor) const {
|
||||
return (MetaTensor::operator==(tensor) && data_ == tensor.data_);
|
||||
|
@ -470,6 +474,19 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
|
|||
>>> data.set_dtype(mindspore.int32)
|
||||
mindspore.int32
|
||||
)mydelimiter")
|
||||
.def("assign_value", &Tensor::AssignValue, R"mydelimiter(
|
||||
Assign another tensor value to this.
|
||||
|
||||
Arg:
|
||||
value (:class:`mindspore.tensor`): The value tensor.
|
||||
|
||||
Examples:
|
||||
>>> data = mindspore.Tensor(np.ones((1, 2), np.float32))
|
||||
>>> data2 = mindspore.Tensor(np.ones((2, 2), np.float32))
|
||||
>>> data.assign_value(data2)
|
||||
>>> data.shape
|
||||
(2, 2)
|
||||
)mydelimiter")
|
||||
.def("__str__", &Tensor::ToString)
|
||||
.def("__repr__", &Tensor::ToStringRepr)
|
||||
.def(py::pickle(
|
||||
|
|
|
@ -173,6 +173,9 @@ class Tensor : public MetaTensor {
|
|||
// It is different from 'operator==' which just compare shape/type/address, it do real value comparison.
|
||||
bool ValueEqual(const Tensor &other) const;
|
||||
|
||||
// assgin value to this tensor
|
||||
Tensor &AssignValue(const Tensor &tensor);
|
||||
|
||||
bool operator==(const Value &other) const override {
|
||||
if (other.isa<Tensor>()) {
|
||||
auto other_ = static_cast<const Tensor &>(other);
|
||||
|
|
|
@ -203,6 +203,8 @@ class Parameter:
|
|||
return self.default_input / other
|
||||
|
||||
def __setitem__(self, index, value):
|
||||
default_input = self.default_input
|
||||
default_input[index] = value
|
||||
return self
|
||||
|
||||
def set_parameter_data(self, data):
|
||||
|
|
|
@ -150,6 +150,8 @@ class Tensor(Tensor_):
|
|||
return out
|
||||
|
||||
def __setitem__(self, index, value):
|
||||
out = tensor_operator_registry.get('__setitem__')(self, index, value)
|
||||
self.assign_value(out)
|
||||
return self
|
||||
|
||||
def __gt__(self, other):
|
||||
|
|
|
@ -26,7 +26,7 @@ hyper_map = base.HyperMap()
|
|||
pack = P.Pack(axis=-1)
|
||||
|
||||
|
||||
def broadcast(broadcast_shape, x):
|
||||
def _broadcast(broadcast_shape, x):
|
||||
"""Broadcast tensor to the required shape."""
|
||||
if F.shape(x) == broadcast_shape:
|
||||
return x
|
||||
|
@ -36,13 +36,13 @@ def broadcast(broadcast_shape, x):
|
|||
return x
|
||||
|
||||
|
||||
def transform_indexing_tensor(broadcast_shape, final_shape, new_shape, x):
|
||||
def _transform_indexing_tensor(broadcast_shape, final_shape, new_shape, x):
|
||||
"""Transform indexing tensor to the required."""
|
||||
x = broadcast(broadcast_shape, x)
|
||||
return broadcast(final_shape, F.reshape(x, new_shape))
|
||||
x = _broadcast(broadcast_shape, x)
|
||||
return _broadcast(final_shape, F.reshape(x, new_shape))
|
||||
|
||||
|
||||
def generate_indices_from_tuple_of_tensor(data, tuple_index, op_name):
|
||||
def _generate_indices_from_tuple_of_tensor(data, tuple_index, op_name):
|
||||
"""Generate an indices tensor from a tuple of tensor."""
|
||||
indices = None
|
||||
check_index_tensor_number = const_utils.check_number_of_index_tensor(F.shape(data), len(tuple_index), op_name)
|
||||
|
@ -52,26 +52,31 @@ def generate_indices_from_tuple_of_tensor(data, tuple_index, op_name):
|
|||
if check_dtypes:
|
||||
shape_tuple = hyper_map(F.shape, tuple_index)
|
||||
broadcast_shape = const_utils.generate_broadcast_shape(shape_tuple, op_name)
|
||||
broadcast_tensors = hyper_map(F.partial(broadcast, broadcast_shape), tuple_index)
|
||||
broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index)
|
||||
indices = pack(broadcast_tensors)
|
||||
return indices
|
||||
|
||||
|
||||
def generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
|
||||
def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
|
||||
"""Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor."""
|
||||
indexes_types = hyper_map(F.typeof, tuple_index)
|
||||
int_positions = const_utils.get_pos_of_int_index(indexes_types)
|
||||
for i in int_positions:
|
||||
tuple_index = F.tuple_setitem(tuple_index, i, F.scalar_to_tensor(tuple_index[i], mstype.int32))
|
||||
indexes_types = hyper_map(F.typeof, tuple_index)
|
||||
tuple_index_new = ()
|
||||
tuple_len = len(tuple_index)
|
||||
for i in range(tuple_len):
|
||||
if i in int_positions:
|
||||
tuple_index_new = tuple_index_new + (F.scalar_to_tensor(tuple_index[i], mstype.int32),)
|
||||
else:
|
||||
tuple_index_new = tuple_index_new + (tuple_index[i],)
|
||||
indexes_types = hyper_map(F.typeof, tuple_index_new)
|
||||
tensor_positions, slice_positions, ellipsis_position = \
|
||||
const_utils.separate_mixed_tensors_index(indexes_types, op_name)
|
||||
tensor_indexes = []
|
||||
slice_indexes = []
|
||||
for i in tensor_positions:
|
||||
tensor_indexes.append(tuple_index[i])
|
||||
tensor_indexes.append(tuple_index_new[i])
|
||||
for j in slice_positions:
|
||||
slice_indexes.append(tuple_index[j])
|
||||
slice_indexes.append(tuple_index_new[j])
|
||||
data_shape = F.shape(data)
|
||||
tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
|
||||
tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes)
|
||||
|
@ -85,14 +90,14 @@ def generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
|
|||
|
||||
slice_number = 0
|
||||
final_index_tensors = []
|
||||
tuple_index_size = len(tuple_index)
|
||||
tuple_index_size = len(tuple_index_new)
|
||||
index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info)
|
||||
for i in range(tuple_index_size):
|
||||
if i in tensor_positions:
|
||||
transform_tensor = transform_indexing_tensor(broadcast_shape,
|
||||
final_shape,
|
||||
index_tensor_new_shape,
|
||||
tuple_index[i])
|
||||
transform_tensor = _transform_indexing_tensor(broadcast_shape,
|
||||
final_shape,
|
||||
index_tensor_new_shape,
|
||||
tuple_index_new[i])
|
||||
final_index_tensors.append(transform_tensor)
|
||||
if i in slice_positions:
|
||||
slice_tensor = const_utils.convert_slice_to_tensor(slice_number,
|
||||
|
@ -114,7 +119,7 @@ def generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
|
|||
return indices
|
||||
|
||||
|
||||
def generate_updates_from_scalar(data, indices, value, op_type):
|
||||
def _generate_updates_from_scalar(data, indices, value, op_type):
|
||||
"""Generate an updates tensor from a scalar."""
|
||||
data_shape = F.shape(data)
|
||||
indices_shape = F.shape(indices)
|
||||
|
@ -122,7 +127,7 @@ def generate_updates_from_scalar(data, indices, value, op_type):
|
|||
return const_utils.convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type)
|
||||
|
||||
|
||||
def generate_updates_from_tuple(data, index, value, op_type):
|
||||
def _generate_updates_from_tuple(data, index, value, op_type):
|
||||
"""Generate an updates tensor from a tuple."""
|
||||
value_types = hyper_map(F.typeof, value)
|
||||
data_dtype = F.dtype(data)
|
||||
|
@ -132,14 +137,14 @@ def generate_updates_from_tuple(data, index, value, op_type):
|
|||
shapes_same = const_utils.check_shapes_same(value_shapes, const_utils.TENSOR_SETITEM)
|
||||
if shapes_same:
|
||||
value = F.pack(value)
|
||||
return generate_updates_from_tensor(data, index, value, op_type)
|
||||
return _generate_updates_from_tensor(data, index, value, op_type)
|
||||
|
||||
data_shape = F.shape(data)
|
||||
index_shape = F.shape(index)
|
||||
return const_utils.convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type)
|
||||
|
||||
|
||||
def generate_updates_from_tensor(data, index, value, op_type):
|
||||
def _generate_updates_from_tensor(data, index, value, op_type):
|
||||
"""Generate an updates tensor from a tensor."""
|
||||
data_shape = F.shape(data)
|
||||
index_shape = F.shape(index)
|
||||
|
@ -152,45 +157,47 @@ def generate_updates_from_tensor(data, index, value, op_type):
|
|||
updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type)
|
||||
need_broadcast = const_utils.check_two_shapes_need_broadcast(updates_shape, value_shape)
|
||||
if need_broadcast:
|
||||
return broadcast(updates_shape, value)
|
||||
return _broadcast(updates_shape, value)
|
||||
return value
|
||||
|
||||
|
||||
def tensor_getitem(self, index):
|
||||
def _tensor_getitem(self, index):
|
||||
"""Handle tensor getitem"""
|
||||
if isinstance(index, Tensor):
|
||||
return tensor_index_by_tensor(self, index)
|
||||
if isinstance(index, tuple):
|
||||
return tensor_index_by_tuple(self, index)
|
||||
if isinstance(index, int):
|
||||
return tensor_index_by_integer(self, index)
|
||||
return _tensor_index_by_integer(self, index)
|
||||
if isinstance(index, slice):
|
||||
return tensor_index_by_slice(self, index)
|
||||
if isinstance(index, bool):
|
||||
return tensor_index_by_bool(self, index)
|
||||
return _tensor_index_by_bool(self, index)
|
||||
if index is None:
|
||||
return F.expand_dims(self, 0)
|
||||
if index is ...:
|
||||
return self
|
||||
raise IndexError(f"Only support integers, slices(`:`), ellipsis(`...`), None, bool and tensor with int32, "
|
||||
f"got {index} with type {type(index)}.")
|
||||
|
||||
|
||||
tensor_operator_registry.register("__getitem__", tensor_getitem)
|
||||
tensor_operator_registry.register("__getitem__", _tensor_getitem)
|
||||
|
||||
|
||||
def tensor_getitem_by_tuple_of_tensor(data, tuple_index):
|
||||
def _tensor_getitem_by_tuple_of_tensor(data, tuple_index):
|
||||
"""Tensor getitem by a tuple of tensor."""
|
||||
indices = generate_indices_from_tuple_of_tensor(data,
|
||||
tuple_index,
|
||||
const_utils.TENSOR_GETITEM)
|
||||
indices = _generate_indices_from_tuple_of_tensor(data,
|
||||
tuple_index,
|
||||
const_utils.TENSOR_GETITEM)
|
||||
result = F.gather_nd(data, indices)
|
||||
return result
|
||||
|
||||
|
||||
def tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index):
|
||||
def _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index):
|
||||
"""Tensor getitem by a tuple of mixed tensor."""
|
||||
indices = generate_indices_from_tuple_of_mixed_tensors(data,
|
||||
tuple_index,
|
||||
const_utils.TENSOR_GETITEM)
|
||||
indices = _generate_indices_from_tuple_of_mixed_tensors(data,
|
||||
tuple_index,
|
||||
const_utils.TENSOR_GETITEM)
|
||||
result = F.gather_nd(data, indices)
|
||||
return result
|
||||
|
||||
|
@ -204,7 +211,7 @@ def tensor_index_by_slice(data, slice_index):
|
|||
return F.strided_slice(data, begin_strides, end_strides, step_strides)
|
||||
|
||||
|
||||
def tensor_index_by_integer(data, number):
|
||||
def _tensor_index_by_integer(data, number):
|
||||
"""Tensor getitem by a single integer number"""
|
||||
shape = F.shape(data)
|
||||
if not shape:
|
||||
|
@ -214,7 +221,7 @@ def tensor_index_by_integer(data, number):
|
|||
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)
|
||||
|
||||
|
||||
def tensor_index_by_bool(data, bool_value):
|
||||
def _tensor_index_by_bool(data, bool_value):
|
||||
"""Tensor getitem by a single bool value"""
|
||||
if bool_value:
|
||||
return F.expand_dims(data, 0)
|
||||
|
@ -225,9 +232,9 @@ def tensor_index_by_number(data, number):
|
|||
"""Tensor getitem by a Number which may be integer/float/bool value"""
|
||||
number_type = const_utils.check_number_index_type(number)
|
||||
if number_type == const_utils.BOOL_:
|
||||
return tensor_index_by_bool(data, number)
|
||||
return _tensor_index_by_bool(data, number)
|
||||
if number_type == const_utils.INT_:
|
||||
return tensor_index_by_integer(data, number)
|
||||
return _tensor_index_by_integer(data, number)
|
||||
return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool.")
|
||||
|
||||
|
||||
|
@ -241,7 +248,7 @@ def tensor_index_by_tensor(data, tensor_index):
|
|||
"the index tensor data type only support mstype.int32.")
|
||||
|
||||
|
||||
def tensor_index_by_tuple_slice(data, t):
|
||||
def _tensor_index_by_tuple_slice(data, t):
|
||||
"""Tensor getitem by a tuple of slice"""
|
||||
shape = F.shape(data)
|
||||
if len(t) > len(shape):
|
||||
|
@ -257,7 +264,303 @@ def tensor_index_by_tuple(data, tuple_index):
|
|||
indexes_types = hyper_map(F.typeof, tuple_index)
|
||||
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM)
|
||||
if index_elements_type == const_utils.NO_TENSOR:
|
||||
return tensor_index_by_tuple_slice(data, tuple_index)
|
||||
return _tensor_index_by_tuple_slice(data, tuple_index)
|
||||
if index_elements_type == const_utils.ALL_TENSOR:
|
||||
return tensor_getitem_by_tuple_of_tensor(data, tuple_index)
|
||||
return tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index)
|
||||
return _tensor_getitem_by_tuple_of_tensor(data, tuple_index)
|
||||
return _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index)
|
||||
|
||||
|
||||
def _tensor_setitem(self, index, value):
|
||||
"""Handle tensor getitem"""
|
||||
if isinstance(index, Tensor):
|
||||
if isinstance(value, (int, float, bool)):
|
||||
return tensor_setitem_by_tensor_with_number(self, index, value)
|
||||
if isinstance(value, Tensor):
|
||||
return tensor_setitem_by_tensor_with_tensor(self, index, value)
|
||||
if isinstance(value, tuple):
|
||||
return tensor_setitem_by_tensor_with_tuple(self, index, value)
|
||||
if isinstance(index, tuple):
|
||||
if isinstance(value, (int, float, bool)):
|
||||
return tensor_setitem_by_tuple_with_number(self, index, value)
|
||||
if isinstance(value, Tensor):
|
||||
return tensor_setitem_by_tuple_with_tensor(self, index, value)
|
||||
if isinstance(value, tuple):
|
||||
return tensor_setitem_by_tuple_with_tuple(self, index, value)
|
||||
if isinstance(index, int):
|
||||
if isinstance(value, (int, float, bool)):
|
||||
return tensor_setitem_by_number_with_number(self, index, value)
|
||||
if isinstance(value, Tensor):
|
||||
return tensor_setitem_by_number_with_tensor(self, index, value)
|
||||
if isinstance(index, slice):
|
||||
if isinstance(value, (int, float, bool)):
|
||||
return tensor_setitem_by_slice_with_number(self, index, value)
|
||||
if isinstance(value, Tensor):
|
||||
return tensor_setitem_by_slice_with_tensor(self, index, value)
|
||||
if isinstance(index, bool):
|
||||
return _tensor_index_by_bool(self, index)
|
||||
if index is ...:
|
||||
if isinstance(value, (int, float, bool)):
|
||||
return tensor_setitem_by_ellipsis_with_number(self, index, value)
|
||||
if isinstance(value, Tensor):
|
||||
return tensor_setitem_by_ellipsis_with_tensor(self, index, value)
|
||||
raise IndexError("Tensor setitem index only support integers, slices(`:`), ellipsis(`...`), None, bool\
|
||||
and tensor with int32, got {} with type{}".format(index, type(index)))
|
||||
|
||||
|
||||
tensor_operator_registry.register("__setitem__", _tensor_setitem)
|
||||
|
||||
|
||||
def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
|
||||
"""Set a tensor item by a int tensor with a tensor."""
|
||||
updates = _generate_updates_from_tensor(data, index, value,
|
||||
const_utils.SET_ITEM_BY_ONE_TENSOR)
|
||||
index = F.expand_dims(index, -1)
|
||||
return P.TensorScatterUpdate()(data, index, updates)
|
||||
|
||||
|
||||
def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
|
||||
"""Set a tensor item by a bool tensor with a tensor."""
|
||||
index_shape = F.shape(index)
|
||||
data_shape = F.shape(data)
|
||||
data_shape = const_utils.check_equal(data_shape, index_shape,
|
||||
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
|
||||
size = F.size(value)
|
||||
size = const_utils.check_equal(1, size,
|
||||
"When assign value is a tensor, its size should be {}, but current size is {}.")
|
||||
dtype = F.dtype(data)
|
||||
u_cast = F.cast(value, dtype)
|
||||
one_data = F.ones_like(data)
|
||||
u = F.tensor_mul(one_data, u_cast)
|
||||
result = F.select(index, u, data)
|
||||
return result
|
||||
|
||||
|
||||
def tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
|
||||
"""setitem by tensor index(dtype is int or bool) with tensor as value"""
|
||||
index_dtype = F.dtype(index)
|
||||
tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype)
|
||||
if tensor_dtype == const_utils.INT_:
|
||||
return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor)
|
||||
return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor)
|
||||
|
||||
|
||||
def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value):
|
||||
"""Set a tensor item by a bool tensor with a scalar."""
|
||||
index_shape = F.shape(index)
|
||||
shape = F.shape(data)
|
||||
shape = const_utils.check_equal(
|
||||
shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
|
||||
dtype = F.dtype(data)
|
||||
u = F.fill(dtype, shape, value)
|
||||
return F.select(index, u, data)
|
||||
|
||||
|
||||
def _tensor_setitem_by_int_tensor_with_scalar(data, index, value):
|
||||
"""Set a tensor item by a int tensor with a scalar."""
|
||||
updates = _generate_updates_from_scalar(data, index, value,
|
||||
const_utils.SET_ITEM_BY_ONE_TENSOR)
|
||||
index = F.expand_dims(index, -1)
|
||||
return P.TensorScatterUpdate()(data, index, updates)
|
||||
|
||||
|
||||
def tensor_setitem_by_tensor_with_number(data, index, value):
|
||||
index_dtype = F.dtype(index)
|
||||
tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype)
|
||||
if tensor_dtype == const_utils.BOOL_:
|
||||
return _tensor_setitem_by_bool_tensor_with_scalar(data, index, value)
|
||||
if tensor_dtype == const_utils.INT_:
|
||||
return _tensor_setitem_by_int_tensor_with_scalar(data, index, value)
|
||||
return const_utils.raise_index_error("For tensor setitem, indexing tensor dtype only supports bool/int")
|
||||
|
||||
|
||||
def tensor_setitem_by_tensor_with_tuple(data, index, value):
|
||||
"""Assigns the tensor by tensor with tuple value."""
|
||||
index_dtype = F.dtype(index)
|
||||
check_dtype = const_utils.check_index_tensor_dtype(index_dtype, const_utils.TENSOR_SETITEM)
|
||||
result = None
|
||||
if check_dtype:
|
||||
result = _tensor_setitem_by_tensor_with_tuple(data, index, value)
|
||||
return result
|
||||
|
||||
|
||||
def _tensor_indices_number(data, data_shape, index, indices, value):
|
||||
"""Assigns a scalar value to the tensor."""
|
||||
data_size = F.size(data)
|
||||
data_dtype = F.dtype(data)
|
||||
indices_size = F.size(indices)
|
||||
indices_size = const_utils.check_indices(indices_size, index)
|
||||
update = F.fill(mstype.int32, (indices_size,), 1)
|
||||
condition_1d = F.scatter_nd(indices, update, (data_size,))
|
||||
condition = F.reshape(condition_1d, data_shape)
|
||||
condition = F.cast(condition, mstype.bool_)
|
||||
value_fill = F.fill(data_dtype, (indices_size,), value)
|
||||
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
|
||||
u = F.reshape(value_1d, data_shape)
|
||||
return F.select(condition, u, data)
|
||||
|
||||
|
||||
def _tensor_setitem_by_tensor_with_tuple(data, index, value):
|
||||
"""Set a tensor item by a tensor with a tuple."""
|
||||
updates = _generate_updates_from_tuple(data, index, value,
|
||||
const_utils.SET_ITEM_BY_ONE_TENSOR)
|
||||
index = F.expand_dims(index, -1)
|
||||
result = P.TensorScatterUpdate()(data, index, updates)
|
||||
return result
|
||||
|
||||
|
||||
def tensor_setitem_by_slice_with_number(data, input_slice, value):
|
||||
"""Givens a scalar assign to tensor by slice"""
|
||||
check_result = const_utils.check_tensor_setitem_index(input_slice)
|
||||
result = None
|
||||
if check_result:
|
||||
data_shape = F.shape(data)
|
||||
indices = const_utils.slice2indices(input_slice, data_shape)
|
||||
is_tuple_int = const_utils.tuple_element_is_int(input_slice)
|
||||
if is_tuple_int:
|
||||
indices = const_utils.integer_to_indices(input_slice, data_shape)
|
||||
result = _tensor_indices_number(data, data_shape, input_slice, indices, value)
|
||||
return result
|
||||
|
||||
|
||||
def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
|
||||
"""Assigns the tensor by tuple with number value."""
|
||||
indexes_types = hyper_map(F.typeof, tuple_index)
|
||||
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
|
||||
|
||||
if index_elements_type == const_utils.NO_TENSOR:
|
||||
return tensor_setitem_by_slice_with_number(data, tuple_index, value)
|
||||
if index_elements_type == const_utils.ALL_TENSOR:
|
||||
indices = _generate_indices_from_tuple_of_tensor(data,
|
||||
tuple_index,
|
||||
const_utils.TENSOR_SETITEM)
|
||||
else:
|
||||
indices = _generate_indices_from_tuple_of_mixed_tensors(data,
|
||||
tuple_index,
|
||||
const_utils.TENSOR_SETITEM)
|
||||
updates = _generate_updates_from_scalar(data,
|
||||
indices,
|
||||
value,
|
||||
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
|
||||
return P.TensorScatterUpdate()(data, indices, updates)
|
||||
|
||||
|
||||
def _tensor_indices_tensor(data, data_shape, index, indices, value):
|
||||
"""Assigns a tensor value to the tensor."""
|
||||
data_size = F.size(data)
|
||||
data_dtype = F.dtype(data)
|
||||
indices_size = F.size(indices)
|
||||
indices_size = const_utils.check_indices(indices_size, index)
|
||||
update = F.fill(mstype.int32, (indices_size,), 1)
|
||||
condition_1d = F.scatter_nd(indices, update, (data_size,))
|
||||
condition = F.reshape(condition_1d, data_shape)
|
||||
condition = F.cast(condition, mstype.bool_)
|
||||
value_fill = None
|
||||
value_size = F.size(value)
|
||||
|
||||
value_size = const_utils.check_indices_value_size(indices_size, value_size)
|
||||
if value_size == 1:
|
||||
value_fill = F.fill(data_dtype, (indices_size,), 1)
|
||||
value = F.cast(value, data_dtype)
|
||||
value_fill = F.tensor_mul(value_fill, value)
|
||||
elif value_size > 1:
|
||||
value_fill = F.reshape(value, (indices_size,))
|
||||
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
|
||||
u = F.reshape(value_1d, data_shape)
|
||||
return F.select(condition, u, data)
|
||||
|
||||
|
||||
def tensor_setitem_by_slice_with_tensor(data, input_slice, value):
|
||||
"""Assigns a tensor value to the tensor by slice."""
|
||||
result = None
|
||||
check_result = const_utils.check_tensor_setitem_index(input_slice)
|
||||
if check_result:
|
||||
data_shape = F.shape(data)
|
||||
indices = const_utils.slice2indices(input_slice, data_shape)
|
||||
is_tuple_int = const_utils.tuple_element_is_int(input_slice)
|
||||
if is_tuple_int:
|
||||
indices = const_utils.integer_to_indices(input_slice, data_shape)
|
||||
result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value)
|
||||
return result
|
||||
|
||||
|
||||
def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
|
||||
"""Assigns the tensor by tuple with tensor value."""
|
||||
indexes_types = hyper_map(F.typeof, tuple_index)
|
||||
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
|
||||
|
||||
if index_elements_type == const_utils.NO_TENSOR:
|
||||
return tensor_setitem_by_slice_with_tensor(data, tuple_index, value)
|
||||
if index_elements_type == const_utils.ALL_TENSOR:
|
||||
indices = _generate_indices_from_tuple_of_tensor(data,
|
||||
tuple_index,
|
||||
const_utils.TENSOR_SETITEM)
|
||||
else:
|
||||
indices = _generate_indices_from_tuple_of_mixed_tensors(data,
|
||||
tuple_index,
|
||||
const_utils.TENSOR_SETITEM)
|
||||
updates = _generate_updates_from_tensor(data,
|
||||
indices,
|
||||
value,
|
||||
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
|
||||
return P.TensorScatterUpdate()(data, indices, updates)
|
||||
|
||||
|
||||
def tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
|
||||
"""Assigns the tensor by tuple with tuple of value."""
|
||||
indexes_types = hyper_map(F.typeof, tuple_index)
|
||||
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
|
||||
|
||||
if index_elements_type == const_utils.ALL_TENSOR:
|
||||
indices = _generate_indices_from_tuple_of_tensor(data,
|
||||
tuple_index,
|
||||
const_utils.TENSOR_SETITEM)
|
||||
else:
|
||||
indices = _generate_indices_from_tuple_of_mixed_tensors(data,
|
||||
tuple_index,
|
||||
const_utils.TENSOR_SETITEM)
|
||||
updates = _generate_updates_from_tuple(data,
|
||||
indices,
|
||||
value,
|
||||
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
|
||||
return P.TensorScatterUpdate()(data, indices, updates)
|
||||
|
||||
|
||||
def tensor_setitem_by_number_with_number(data, index, value):
|
||||
"""Assigns the tensor by number with number value."""
|
||||
data_shape = F.shape(data)
|
||||
indices = const_utils.integer_to_indices(index, data_shape)
|
||||
return _tensor_indices_number(data, data_shape, index, indices, value)
|
||||
|
||||
|
||||
def tensor_setitem_by_number_with_tensor(data, index, value):
|
||||
"""Assigns the tensor by number with tensor value."""
|
||||
data_shape = F.shape(data)
|
||||
indices = const_utils.integer_to_indices(index, data_shape)
|
||||
return _tensor_indices_tensor(data, data_shape, index, indices, value)
|
||||
|
||||
|
||||
def tensor_setitem_by_ellipsis_with_number(data, index, value):
|
||||
"""Assigns the tensor by ellipsis with number value."""
|
||||
data_shape = F.shape(data)
|
||||
data_dtype = F.dtype(data)
|
||||
return F.fill(data_dtype, data_shape, value)
|
||||
|
||||
|
||||
def tensor_setitem_by_ellipsis_with_tensor(data, index, value):
|
||||
"""Assigns the tensor by ellipsis with tensor value."""
|
||||
result = None
|
||||
data_shape = F.shape(data)
|
||||
data_dtype = F.dtype(data)
|
||||
data_size = F.size(data)
|
||||
value_shape = F.shape(value)
|
||||
value_size = F.size(value)
|
||||
check_result = const_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size)
|
||||
if check_result:
|
||||
if data_size == value_size:
|
||||
result = F.reshape(value, data_shape)
|
||||
result = F.cast(result, data_dtype)
|
||||
elif value_size == 1:
|
||||
param1 = F.fill(data_dtype, data_shape, 1)
|
||||
param2 = F.cast(value, data_dtype)
|
||||
result = F.tensor_mul(param1, param2)
|
||||
return result
|
||||
|
|
|
@ -16,10 +16,8 @@
|
|||
"""Implementation for setitem."""
|
||||
|
||||
from . import _compile_utils as compile_utils
|
||||
from . import _constexpr_utils as const_utils
|
||||
from ... import functional as F
|
||||
from ...composite import base
|
||||
from ....common import dtype as mstype
|
||||
|
||||
setitem = base.MultitypeFuncGraph('setitem')
|
||||
|
||||
|
@ -139,11 +137,7 @@ def _tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
|
|||
Outputs:
|
||||
Tensor, element type and shape is same as data.
|
||||
"""
|
||||
index_dtype = F.dtype(index)
|
||||
tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype)
|
||||
if tensor_dtype == const_utils.INT_:
|
||||
return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor)
|
||||
return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor)
|
||||
return compile_utils.tensor_setitem_by_tensor_with_tensor(data, index, value_tensor)
|
||||
|
||||
|
||||
@setitem.register("Tensor", "Tensor", "Number")
|
||||
|
@ -166,11 +160,7 @@ def _tensor_setitem_by_tensor_with_number(data, index, value):
|
|||
Outputs:
|
||||
Tensor, element type and shape is same as data.
|
||||
"""
|
||||
index_dtype = F.dtype(index)
|
||||
tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype)
|
||||
if tensor_dtype == const_utils.BOOL_:
|
||||
return _tensor_setitem_by_bool_tensor_with_scalar(data, index, value)
|
||||
return _tensor_setitem_by_int_tensor_with_scalar(data, index, value)
|
||||
return compile_utils.tensor_setitem_by_tensor_with_number(data, index, value)
|
||||
|
||||
|
||||
@setitem.register("Tensor", "Tuple", "Number")
|
||||
|
@ -191,24 +181,7 @@ def _tensor_setitem_by_tuple_with_number(data, tuple_index, value):
|
|||
Outputs:
|
||||
Tensor, element type and shape is same as data.
|
||||
"""
|
||||
indexes_types = compile_utils.hyper_map(F.typeof, tuple_index)
|
||||
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
|
||||
|
||||
if index_elements_type == const_utils.NO_TENSOR:
|
||||
return _tensor_assgin_number(data, tuple_index, value)
|
||||
if index_elements_type == const_utils.ALL_TENSOR:
|
||||
indices = compile_utils.generate_indices_from_tuple_of_tensor(data,
|
||||
tuple_index,
|
||||
const_utils.TENSOR_SETITEM)
|
||||
else:
|
||||
indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data,
|
||||
tuple_index,
|
||||
const_utils.TENSOR_SETITEM)
|
||||
updates = compile_utils.generate_updates_from_scalar(data,
|
||||
indices,
|
||||
value,
|
||||
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
|
||||
return F.scatter_nd_update(data, indices, updates)
|
||||
return compile_utils.tensor_setitem_by_tuple_with_number(data, tuple_index, value)
|
||||
|
||||
|
||||
@setitem.register("Tensor", "Tuple", "Tensor")
|
||||
|
@ -229,24 +202,7 @@ def _tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
|
|||
Outputs:
|
||||
Tensor, element type and shape is same as data.
|
||||
"""
|
||||
indexes_types = compile_utils.hyper_map(F.typeof, tuple_index)
|
||||
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
|
||||
|
||||
if index_elements_type == const_utils.NO_TENSOR:
|
||||
return _tensor_assgin_tensor(data, tuple_index, value)
|
||||
if index_elements_type == const_utils.ALL_TENSOR:
|
||||
indices = compile_utils.generate_indices_from_tuple_of_tensor(data,
|
||||
tuple_index,
|
||||
const_utils.TENSOR_SETITEM)
|
||||
else:
|
||||
indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data,
|
||||
tuple_index,
|
||||
const_utils.TENSOR_SETITEM)
|
||||
updates = compile_utils.generate_updates_from_tensor(data,
|
||||
indices,
|
||||
value,
|
||||
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
|
||||
return F.scatter_nd_update(data, indices, updates)
|
||||
return compile_utils.tensor_setitem_by_tuple_with_tensor(data, tuple_index, value)
|
||||
|
||||
|
||||
@setitem.register("Tensor", "Tuple", "Tuple")
|
||||
|
@ -268,22 +224,7 @@ def _tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
|
|||
Outputs:
|
||||
Tensor, element type and shape is same as data.
|
||||
"""
|
||||
indexes_types = compile_utils.hyper_map(F.typeof, tuple_index)
|
||||
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
|
||||
|
||||
if index_elements_type == const_utils.ALL_TENSOR:
|
||||
indices = compile_utils.generate_indices_from_tuple_of_tensor(data,
|
||||
tuple_index,
|
||||
const_utils.TENSOR_SETITEM)
|
||||
else:
|
||||
indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data,
|
||||
tuple_index,
|
||||
const_utils.TENSOR_SETITEM)
|
||||
updates = compile_utils.generate_updates_from_tuple(data,
|
||||
indices,
|
||||
value,
|
||||
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
|
||||
return F.scatter_nd_update(data, indices, updates)
|
||||
return compile_utils.tensor_setitem_by_tuple_with_tuple(data, tuple_index, value)
|
||||
|
||||
|
||||
@setitem.register("Tensor", "Tensor", "Tuple")
|
||||
|
@ -299,12 +240,7 @@ def _tensor_setitem_by_tensor_v2(data, index, value):
|
|||
Outputs:
|
||||
Tensor, element type and shape is same as data.
|
||||
"""
|
||||
index_dtype = F.dtype(index)
|
||||
check_dtype = const_utils.check_index_tensor_dtype(index_dtype, const_utils.TENSOR_SETITEM)
|
||||
result = None
|
||||
if check_dtype:
|
||||
result = _tensor_setitem_by_tensor_with_tuple(data, index, value)
|
||||
return result
|
||||
return compile_utils.tensor_setitem_by_tensor_with_tuple(data, index, value)
|
||||
|
||||
|
||||
@setitem.register("Tensor", "Slice", "Tensor")
|
||||
|
@ -326,7 +262,7 @@ def _tensor_setitem_with_slice_v3(data, input_slice, value):
|
|||
Outputs:
|
||||
Tensor, element type and shape is same as data.
|
||||
"""
|
||||
return _tensor_assgin_tensor(data, input_slice, value)
|
||||
return compile_utils.tensor_setitem_by_slice_with_tensor(data, input_slice, value)
|
||||
|
||||
|
||||
@setitem.register("Tensor", "Slice", "Number")
|
||||
|
@ -348,168 +284,28 @@ def _tensor_setitem_with_slice_v1(data, input_slice, value):
|
|||
Outputs:
|
||||
Tensor, element type and shape is same as data.
|
||||
"""
|
||||
return _tensor_assgin_number(data, input_slice, value)
|
||||
|
||||
|
||||
def _tensor_assgin_number(data, input_slice, value):
|
||||
"""Givens a scalar assign to tensor by slice"""
|
||||
check_result = const_utils.check_tensor_setitem_index(input_slice)
|
||||
result = None
|
||||
if check_result:
|
||||
data_shape = F.shape(data)
|
||||
indices = const_utils.slice2indices(input_slice, data_shape)
|
||||
is_tuple_int = const_utils.tuple_element_is_int(input_slice)
|
||||
if is_tuple_int:
|
||||
indices = const_utils.integer_to_indices(input_slice, data_shape)
|
||||
result = _tensor_indices_number(data, data_shape, input_slice, indices, value)
|
||||
return result
|
||||
return compile_utils.tensor_setitem_by_slice_with_number(data, input_slice, value)
|
||||
|
||||
|
||||
@setitem.register("Tensor", "Number", "Number")
|
||||
def _tensor_setitem_with_int_v1(data, index, value):
|
||||
"""Syntax: A[1] = 3"""
|
||||
data_shape = F.shape(data)
|
||||
indices = const_utils.integer_to_indices(index, data_shape)
|
||||
return _tensor_indices_number(data, data_shape, index, indices, value)
|
||||
return compile_utils.tensor_setitem_by_number_with_number(data, index, value)
|
||||
|
||||
|
||||
@setitem.register("Tensor", "Number", "Tensor")
|
||||
def _tensor_setitem_with_int_v2(data, index, value):
|
||||
"""Syntax: A[1] = Tensor"""
|
||||
data_shape = F.shape(data)
|
||||
indices = const_utils.integer_to_indices(index, data_shape)
|
||||
return _tensor_indices_tensor(data, data_shape, index, indices, value)
|
||||
return compile_utils.tensor_setitem_by_number_with_tensor(data, index, value)
|
||||
|
||||
|
||||
@setitem.register("Tensor", "Ellipsis", "Number")
|
||||
def _tensor_setitem_with_ellipsis_v1(data, index, value):
|
||||
"""Syntax: A[...] = number."""
|
||||
data_shape = F.shape(data)
|
||||
data_dtype = F.dtype(data)
|
||||
return F.fill(data_dtype, data_shape, value)
|
||||
return compile_utils.tensor_setitem_by_ellipsis_with_number(data, index, value)
|
||||
|
||||
|
||||
@setitem.register("Tensor", "Ellipsis", "Tensor")
|
||||
def _tensor_setitem_with_ellipsis_v2(data, index, value):
|
||||
"""Syntax: A[...] = Tensor."""
|
||||
result = None
|
||||
data_shape = F.shape(data)
|
||||
data_dtype = F.dtype(data)
|
||||
data_size = F.size(data)
|
||||
value_shape = F.shape(value)
|
||||
value_size = F.size(value)
|
||||
check_result = const_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size)
|
||||
if check_result:
|
||||
if data_size == value_size:
|
||||
result = F.reshape(value, data_shape)
|
||||
result = F.cast(result, data_dtype)
|
||||
elif value_size == 1:
|
||||
param1 = F.fill(data_dtype, data_shape, 1)
|
||||
param2 = F.cast(value, data_dtype)
|
||||
result = F.tensor_mul(param1, param2)
|
||||
return result
|
||||
|
||||
|
||||
def _tensor_assgin_tensor(data, input_slice, value):
|
||||
"""Assigns a tensor value to the tensor by slice."""
|
||||
result = None
|
||||
check_result = const_utils.check_tensor_setitem_index(input_slice)
|
||||
if check_result:
|
||||
data_shape = F.shape(data)
|
||||
indices = const_utils.slice2indices(input_slice, data_shape)
|
||||
is_tuple_int = const_utils.tuple_element_is_int(input_slice)
|
||||
if is_tuple_int:
|
||||
indices = const_utils.integer_to_indices(input_slice, data_shape)
|
||||
result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value)
|
||||
return result
|
||||
|
||||
|
||||
def _tensor_indices_tensor(data, data_shape, index, indices, value):
|
||||
"""Assigns a tensor value to the tensor."""
|
||||
data_size = F.size(data)
|
||||
data_dtype = F.dtype(data)
|
||||
indices_size = F.size(indices)
|
||||
indices_size = const_utils.check_indices(indices_size, index)
|
||||
update = F.fill(mstype.int32, (indices_size,), 1)
|
||||
condition_1d = F.scatter_nd(indices, update, (data_size,))
|
||||
condition = F.reshape(condition_1d, data_shape)
|
||||
condition = F.cast(condition, mstype.bool_)
|
||||
value_fill = None
|
||||
value_size = F.size(value)
|
||||
|
||||
value_size = const_utils.check_indices_value_size(indices_size, value_size)
|
||||
if value_size == 1:
|
||||
value_fill = F.fill(data_dtype, (indices_size,), 1)
|
||||
value = F.cast(value, data_dtype)
|
||||
value_fill = F.tensor_mul(value_fill, value)
|
||||
elif value_size > 1:
|
||||
value_fill = F.reshape(value, (indices_size,))
|
||||
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
|
||||
u = F.reshape(value_1d, data_shape)
|
||||
return F.select(condition, u, data)
|
||||
|
||||
|
||||
def _tensor_indices_number(data, data_shape, index, indices, value):
|
||||
"""Assigns a scalar value to the tensor."""
|
||||
data_size = F.size(data)
|
||||
data_dtype = F.dtype(data)
|
||||
indices_size = F.size(indices)
|
||||
indices_size = const_utils.check_indices(indices_size, index)
|
||||
update = F.fill(mstype.int32, (indices_size,), 1)
|
||||
condition_1d = F.scatter_nd(indices, update, (data_size,))
|
||||
condition = F.reshape(condition_1d, data_shape)
|
||||
condition = F.cast(condition, mstype.bool_)
|
||||
value_fill = F.fill(data_dtype, (indices_size,), value)
|
||||
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
|
||||
u = F.reshape(value_1d, data_shape)
|
||||
return F.select(condition, u, data)
|
||||
|
||||
|
||||
def _tensor_setitem_by_tensor_with_tuple(data, index, value):
|
||||
"""Set a tensor item by a tensor with a tuple."""
|
||||
updates = compile_utils.generate_updates_from_tuple(data, index, value,
|
||||
const_utils.SET_ITEM_BY_ONE_TENSOR)
|
||||
result = F.scatter_update(data, index, updates)
|
||||
return result
|
||||
|
||||
|
||||
def _tensor_setitem_by_int_tensor_with_scalar(data, index, value):
|
||||
"""Set a tensor item by a int tensor with a scalar."""
|
||||
updates = compile_utils.generate_updates_from_scalar(data, index, value,
|
||||
const_utils.SET_ITEM_BY_ONE_TENSOR)
|
||||
return F.scatter_update(data, index, updates)
|
||||
|
||||
|
||||
def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value):
|
||||
"""Set a tensor item by a bool tensor with a scalar."""
|
||||
index_shape = F.shape(index)
|
||||
shape = F.shape(data)
|
||||
shape = const_utils.check_equal(
|
||||
shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
|
||||
dtype = F.dtype(data)
|
||||
u = F.fill(dtype, shape, value)
|
||||
return F.select(index, u, data)
|
||||
|
||||
|
||||
def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
|
||||
"""Set a tensor item by a int tensor with a tensor."""
|
||||
updates = compile_utils.generate_updates_from_tensor(data, index, value,
|
||||
const_utils.SET_ITEM_BY_ONE_TENSOR)
|
||||
return F.scatter_update(data, index, updates)
|
||||
|
||||
|
||||
def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
|
||||
"""Set a tensor item by a bool tensor with a tensor."""
|
||||
index_shape = F.shape(index)
|
||||
data_shape = F.shape(data)
|
||||
data_shape = const_utils.check_equal(data_shape, index_shape,
|
||||
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
|
||||
size = F.size(value)
|
||||
size = const_utils.check_equal(1, size,
|
||||
"When assign value is a tensor, its size should be {}, but current size is {}.")
|
||||
dtype = F.dtype(data)
|
||||
u_cast = F.cast(value, dtype)
|
||||
one_data = F.ones_like(data)
|
||||
u = F.tensor_mul(one_data, u_cast)
|
||||
result = F.select(index, u, data)
|
||||
return result
|
||||
return compile_utils.tensor_setitem_by_ellipsis_with_tensor(data, index, value)
|
||||
|
|
|
@ -20,10 +20,14 @@ from mindspore import Tensor, Parameter
|
|||
from mindspore import context
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore.nn import Cell
|
||||
from mindspore.common.parameter import ParameterTuple
|
||||
from mindspore.ops import composite as C
|
||||
|
||||
|
||||
def setup_module():
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
|
||||
|
||||
class NetWorkSlicePositive(Cell):
|
||||
def __init__(self):
|
||||
super(NetWorkSlicePositive, self).__init__()
|
||||
|
@ -139,7 +143,7 @@ class TensorGetItemByThreeTensors(Cell):
|
|||
return ret0, ret1, ret2
|
||||
|
||||
|
||||
def Xtest_getitem_by_tensors():
|
||||
def test_getitem_by_tensors():
|
||||
net = TensorGetItemByThreeTensors()
|
||||
input_x = np.arange(6*8*10).reshape(6, 8, 10).astype(np.int32)
|
||||
index_0 = np.random.randint(6, size=(3, 4, 5)).astype(np.int32)
|
||||
|
@ -155,119 +159,140 @@ def Xtest_getitem_by_tensors():
|
|||
assert np.all(output2.asnumpy() == input_x[index_0, index_1, index_2] + np.ones([5, 3, 4, 5]))
|
||||
|
||||
|
||||
class TensorGetItemByMixedTensors_0(Cell):
|
||||
def __init__(self):
|
||||
super(TensorGetItemByMixedTensors_0, self).__init__()
|
||||
self.const = Tensor(np.ones((3, 4, 5, 3, 6, 5), np.float32))
|
||||
class TensorGetItemByMixedTensorsBasicCase(Cell):
|
||||
def __init__(self, c0, c1, c2, c3, c4, c5):
|
||||
super(TensorGetItemByMixedTensorsBasicCase, self).__init__()
|
||||
self.const0 = Tensor(c0)
|
||||
self.const1 = Tensor(c1)
|
||||
self.const2 = Tensor(c2)
|
||||
self.const3 = Tensor(c3)
|
||||
self.const4 = Tensor(c4)
|
||||
self.const5 = Tensor(c5)
|
||||
|
||||
def construct(self, tensor, index_0, index_1):
|
||||
ret = tensor[index_0, index_1, 0:3, ..., 0:5, 3] + self.const
|
||||
return ret
|
||||
ret0 = tensor[index_0, index_1, 0:3] + self.const0
|
||||
ret1 = tensor[0:3, index_0, ...] + self.const1
|
||||
ret2 = tensor[0, index_0, index_1] + self.const2
|
||||
ret3 = tensor[..., index_0, 0:3] + self.const3
|
||||
ret4 = tensor[0:2, index_0, index_1] + self.const4
|
||||
ret5 = tensor[..., index_0, index_1] + self.const5
|
||||
return ret0, ret1, ret2, ret3, ret4, ret5
|
||||
|
||||
|
||||
class TensorGetItemByMixedTensors_1(Cell):
|
||||
def __init__(self):
|
||||
super(TensorGetItemByMixedTensors_1, self).__init__()
|
||||
self.const = Tensor(np.ones((3, 4, 5, 3, 5, 5), np.float32))
|
||||
|
||||
def construct(self, tensor, index_0, index_1):
|
||||
ret = tensor[0:3, index_0, ..., index_1, 3, 0:5] + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorGetItemByMixedTensors_2(Cell):
|
||||
def __init__(self):
|
||||
super(TensorGetItemByMixedTensors_2, self).__init__()
|
||||
self.const = Tensor(np.ones((3, 4, 5, 6, 7), np.float32))
|
||||
|
||||
def construct(self, tensor, index_0, index_1):
|
||||
ret = tensor[0, index_0, index_1, ..., 3] + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorGetItemByMixedTensors_3(Cell):
|
||||
def __init__(self):
|
||||
super(TensorGetItemByMixedTensors_3, self).__init__()
|
||||
self.const = Tensor(np.ones((3, 4, 5, 3, 4, 3, 5), np.float32))
|
||||
|
||||
def construct(self, tensor, index_0, index_1):
|
||||
ret = tensor[..., index_0, 0:3, index_1, 0:5] + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorGetItemByMixedTensors_4(Cell):
|
||||
def __init__(self):
|
||||
super(TensorGetItemByMixedTensors_4, self).__init__()
|
||||
self.const = Tensor(np.ones((2, 2, 3, 4, 5, 3, 9), np.float32))
|
||||
|
||||
def construct(self, tensor, index_0, index_1, index_2):
|
||||
ret = tensor[0:2, index_0, index_1, 2, index_2, 0:3, ...] + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorGetItemByMixedTensors_5(Cell):
|
||||
def __init__(self):
|
||||
super(TensorGetItemByMixedTensors_5, self).__init__()
|
||||
self.const = Tensor(np.ones((2, 3, 4, 5, 2, 6), np.float32))
|
||||
|
||||
def construct(self, tensor, index_0, index_1, index_2):
|
||||
ret = tensor[0:2, index_0, index_1, ..., index_2, 2] + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorGetItemByMixedTensors_6(Cell):
|
||||
def __init__(self):
|
||||
super(TensorGetItemByMixedTensors_6, self).__init__()
|
||||
self.const = Tensor(np.ones((3, 4, 2, 3, 4, 5), np.float32))
|
||||
|
||||
def construct(self, tensor, index_0, index_1, index_2):
|
||||
ret = tensor[..., index_0, index_1, index_2, 3] + self.const
|
||||
return ret
|
||||
def test_getitem_by_mixed_tensors():
|
||||
const0 = np.ones((3, 4, 5, 3), np.float32)
|
||||
const1 = np.ones((3, 3, 4, 5, 5), np.float32)
|
||||
const2 = np.ones((3, 4, 5), np.float32)
|
||||
const3 = np.ones((3, 3, 4, 5, 3), np.float32)
|
||||
const4 = np.ones((2, 3, 4, 5), np.float32)
|
||||
const5 = np.ones((3, 3, 4, 5), np.float32)
|
||||
net = TensorGetItemByMixedTensorsBasicCase(const0, const1, const2, const3, const4, const5)
|
||||
input_np = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32)
|
||||
input_ms = Tensor(input_np, mstype.float32)
|
||||
index_np_0 = np.random.randint(3, size=(3, 4, 5)).astype(np.int32)
|
||||
index_np_1 = np.random.randint(4, size=(4, 5)).astype(np.int32)
|
||||
index_0 = Tensor(index_np_0, mstype.int32)
|
||||
index_1 = Tensor(index_np_1, mstype.int32)
|
||||
out0, out1, out2, out3, out4, out5 = net(input_ms, index_0, index_1)
|
||||
assert np.all(out0.asnumpy() == (input_np[index_np_0, index_np_1, 0:3] + const0))
|
||||
assert np.all(out1.asnumpy() == (input_np[0:3, index_np_0, ...] + const1))
|
||||
assert np.all(out2.asnumpy() == (input_np[0, index_np_0, index_np_1] + const2))
|
||||
assert np.all(out3.asnumpy() == (input_np[..., index_np_0, 0:3] + const3))
|
||||
assert np.all(out4.asnumpy() == (input_np[0:2, index_np_0, index_np_1] + const4))
|
||||
assert np.all(out5.asnumpy() == (input_np[..., index_np_0, index_np_1] + const5))
|
||||
|
||||
|
||||
class TensorSetItemByMixedTensors_0(Cell):
|
||||
def __init__(self, value):
|
||||
super(TensorSetItemByMixedTensors_0, self).__init__()
|
||||
self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8, 9), np.float32))
|
||||
self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)),
|
||||
self.const = Tensor(np.ones((3, 4, 5), np.float32))
|
||||
self.param = Parameter(Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)),
|
||||
mstype.float32),
|
||||
name="x")
|
||||
self.value = value
|
||||
|
||||
def construct(self, index_0, index_1, index_2):
|
||||
self.param[0:2, index_0, index_1, 2, index_2, 0:3, ...] = self.value
|
||||
self.param[0:2, index_0, index_1] = self.value
|
||||
ret = self.param + self.const
|
||||
return ret
|
||||
|
||||
|
||||
def test_setitem_by_mixed_tensors_0():
|
||||
value = 88.0
|
||||
net = TensorSetItemByMixedTensors_0(value)
|
||||
index_0 = np.random.randint(3, size=(3, 4, 5))
|
||||
index_1 = np.random.randint(4, size=(4, 5))
|
||||
index_2 = np.random.randint(3, size=(2, 1, 4, 5))
|
||||
index_0_ms = Tensor(index_0, mstype.int32)
|
||||
index_1_ms = Tensor(index_1, mstype.int32)
|
||||
index_2_ms = Tensor(index_2, mstype.int32)
|
||||
input_np = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32)
|
||||
const = np.ones((3, 4, 5), np.float32)
|
||||
out = net(index_0_ms, index_1_ms, index_2_ms)
|
||||
input_np[0:2, index_0, index_1] = value
|
||||
assert np.all(out.asnumpy() == (input_np + const))
|
||||
|
||||
|
||||
class TensorSetItemByMixedTensors_1(Cell):
|
||||
def __init__(self, value):
|
||||
super(TensorSetItemByMixedTensors_1, self).__init__()
|
||||
self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float32))
|
||||
self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
|
||||
self.const = Tensor(np.ones((3, 4, 5), np.float32))
|
||||
self.param = Parameter(Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32),
|
||||
name="x")
|
||||
self.value = value
|
||||
|
||||
def construct(self, index_0, index_1, index_2):
|
||||
self.param[0:2, index_0, index_1, ..., index_2, 2] = self.value
|
||||
self.param[0:2, index_0, ...] = self.value
|
||||
ret = self.param + self.const
|
||||
return ret
|
||||
|
||||
|
||||
def test_setitem_by_mixed_tensors_1():
|
||||
value = 88.0
|
||||
net = TensorSetItemByMixedTensors_1(value)
|
||||
index_0 = np.random.randint(3, size=(3, 4, 5))
|
||||
index_1 = np.random.randint(4, size=(4, 5))
|
||||
index_2 = np.random.randint(3, size=(2, 1, 4, 5))
|
||||
index_0_ms = Tensor(index_0, mstype.int32)
|
||||
index_1_ms = Tensor(index_1, mstype.int32)
|
||||
index_2_ms = Tensor(index_2, mstype.int32)
|
||||
input_np = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32)
|
||||
const = np.ones((3, 4, 5), np.float32)
|
||||
out = net(index_0_ms, index_1_ms, index_2_ms)
|
||||
input_np[0:2, index_0, ...] = value
|
||||
assert np.all(out.asnumpy() == (input_np + const))
|
||||
|
||||
|
||||
class TensorSetItemByMixedTensors_2(Cell):
|
||||
def __init__(self, value):
|
||||
super(TensorSetItemByMixedTensors_2, self).__init__()
|
||||
self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float16))
|
||||
self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float16),
|
||||
self.const = Tensor(np.ones((3, 4, 5), np.float16))
|
||||
self.param = Parameter(Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float16),
|
||||
name="x")
|
||||
self.value = value
|
||||
|
||||
def construct(self, index_0, index_1, index_2):
|
||||
self.param[..., index_0, index_1, index_2, 3] = self.value
|
||||
self.param[..., index_0, 1] = self.value
|
||||
ret = self.param + self.const
|
||||
return ret
|
||||
|
||||
|
||||
def test_setitem_by_mixed_tensors_2():
|
||||
value = 88.0
|
||||
net = TensorSetItemByMixedTensors_2(value)
|
||||
index_0 = np.random.randint(3, size=(3, 4, 5))
|
||||
index_1 = np.random.randint(4, size=(4, 5))
|
||||
index_2 = np.random.randint(3, size=(2, 1, 4, 5))
|
||||
index_0_ms = Tensor(index_0, mstype.int32)
|
||||
index_1_ms = Tensor(index_1, mstype.int32)
|
||||
index_2_ms = Tensor(index_2, mstype.int32)
|
||||
input_np = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32)
|
||||
const = np.ones((3, 4, 5), np.float32)
|
||||
out = net(index_0_ms, index_1_ms, index_2_ms)
|
||||
input_np[..., index_0, 1] = value
|
||||
assert np.all(out.asnumpy() == (input_np + const))
|
||||
|
||||
|
||||
class TensorGetItemByMixedTensorsTypeError(Cell):
|
||||
def __init__(self):
|
||||
super(TensorGetItemByMixedTensorsTypeError, self).__init__()
|
||||
|
@ -277,13 +302,13 @@ class TensorGetItemByMixedTensorsTypeError(Cell):
|
|||
return ret
|
||||
|
||||
|
||||
class TensorGetItemByMixedTensorsNumberError(Cell):
|
||||
def __init__(self):
|
||||
super(TensorGetItemByMixedTensorsNumberError, self).__init__()
|
||||
|
||||
def construct(self, x, index_0, index_1):
|
||||
ret = x[index_0, index_1, 0:3, ..., index_1, index_0]
|
||||
return ret
|
||||
def test_getitem_by_mixedtensor_exception():
|
||||
input_ms = Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32)
|
||||
index_0 = Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32)
|
||||
index_1 = Tensor(np.random.randint(4, size=(3, 4, 5)), mstype.int32)
|
||||
net1 = TensorGetItemByMixedTensorsTypeError()
|
||||
with pytest.raises(TypeError):
|
||||
net1(input_ms, index_0, index_1)
|
||||
|
||||
|
||||
class TensorSetItemByOneTensorWithNumber(Cell):
|
||||
|
@ -299,6 +324,18 @@ class TensorSetItemByOneTensorWithNumber(Cell):
|
|||
return ret
|
||||
|
||||
|
||||
def test_setitem_one_tensor_with_number():
|
||||
value = 0.0
|
||||
net = TensorSetItemByOneTensorWithNumber(value)
|
||||
index_np = np.random.randint(4, size=(5, 4))
|
||||
index = Tensor(index_np, mstype.int32)
|
||||
input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8))
|
||||
const = np.ones((6, 7, 8)).astype(np.float32)
|
||||
out = net(index)
|
||||
input_data[index_np] = value
|
||||
assert np.all(out.asnumpy() == (input_data + const))
|
||||
|
||||
|
||||
class TensorSetItemByOneTensorWithTensor(Cell):
|
||||
def __init__(self):
|
||||
super(TensorSetItemByOneTensorWithTensor, self).__init__()
|
||||
|
@ -311,6 +348,19 @@ class TensorSetItemByOneTensorWithTensor(Cell):
|
|||
return ret
|
||||
|
||||
|
||||
def test_setitem_by_one_tensor_with_tensor():
|
||||
net = TensorSetItemByOneTensorWithTensor()
|
||||
index_np = np.random.randint(4, size=(5, 4))
|
||||
index = Tensor(index_np, mstype.int32)
|
||||
input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8))
|
||||
const = np.ones((6, 7, 8)).astype(np.float32)
|
||||
value = np.zeros((4, 7, 8)).astype(np.float32)
|
||||
value_ms = Tensor(value, mstype.float32)
|
||||
out = net(index, value_ms)
|
||||
input_data[index_np] = value
|
||||
assert np.all(out.asnumpy() == (input_data + const))
|
||||
|
||||
|
||||
class TensorSetItemByOneTensorWithTupleOfNumber(Cell):
|
||||
def __init__(self, value):
|
||||
super(TensorSetItemByOneTensorWithTupleOfNumber, self).__init__()
|
||||
|
@ -324,6 +374,18 @@ class TensorSetItemByOneTensorWithTupleOfNumber(Cell):
|
|||
return ret
|
||||
|
||||
|
||||
def test_setitem_by_one_tensor_with_tuple_number():
|
||||
value = (0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7)
|
||||
net = TensorSetItemByOneTensorWithTupleOfNumber(value)
|
||||
input_np = np.random.randint(5, size=(5, 4))
|
||||
input_ms = Tensor(input_np, mstype.int32)
|
||||
input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32)
|
||||
const = np.ones((6, 7, 8)).astype(np.float32)
|
||||
out = net(input_ms)
|
||||
input_data[input_np] = value
|
||||
assert np.all(out.asnumpy() == (input_data + const))
|
||||
|
||||
|
||||
class TensorSetItemByOneTensorWithTupleOfTensor(Cell):
|
||||
def __init__(self):
|
||||
super(TensorSetItemByOneTensorWithTupleOfTensor, self).__init__()
|
||||
|
@ -336,6 +398,23 @@ class TensorSetItemByOneTensorWithTupleOfTensor(Cell):
|
|||
return ret
|
||||
|
||||
|
||||
def test_setitem_by_one_tensor_with_tuple_tensors():
|
||||
net = TensorSetItemByOneTensorWithTupleOfTensor()
|
||||
input_np = np.random.randint(6, size=(5, 4)).astype(np.int32)
|
||||
input_ms = Tensor(input_np, mstype.int32)
|
||||
input_data = np.arange(6 * 3 * 8).reshape((6, 3, 8)).astype(np.float32)
|
||||
value_0_np = np.zeros((8,), np.float32)
|
||||
value_1_np = np.ones((8,), np.float32)
|
||||
value_2_np = np.ones((8,), np.float32)*2
|
||||
value_0 = Tensor(value_0_np)
|
||||
value_1 = Tensor(value_1_np)
|
||||
value_2 = Tensor(value_2_np)
|
||||
const = np.ones((6, 3, 8)).astype(np.float32)
|
||||
out = net(input_ms, value_0, value_1, value_2)
|
||||
input_data[input_np] = (value_0_np, value_1_np, value_2_np)
|
||||
assert np.all(out.asnumpy() == (input_data + const))
|
||||
|
||||
|
||||
class TensorSetItemByTensorsWithNumber(Cell):
|
||||
def __init__(self, value):
|
||||
super(TensorSetItemByTensorsWithNumber, self).__init__()
|
||||
|
@ -349,6 +428,22 @@ class TensorSetItemByTensorsWithNumber(Cell):
|
|||
return ret
|
||||
|
||||
|
||||
def test_setitem_by_tensors_with_number():
|
||||
value = 0.0
|
||||
net = TensorSetItemByTensorsWithNumber(value)
|
||||
index_0 = np.random.randint(6, size=(3, 4, 5))
|
||||
index_1 = np.random.randint(7, size=(4, 5))
|
||||
index_2 = np.random.randint(8, size=(5, 3, 4, 5))
|
||||
index_0_ms = Tensor(index_0, mstype.int32)
|
||||
index_1_ms = Tensor(index_1, mstype.int32)
|
||||
index_2_ms = Tensor(index_2, mstype.int32)
|
||||
out = net(index_0_ms, index_1_ms, index_2_ms)
|
||||
const = np.ones((6, 7, 8)).astype(np.float32)
|
||||
input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32)
|
||||
input_data[index_0, index_1, index_2] = value
|
||||
assert np.all(out.asnumpy() == (input_data + const))
|
||||
|
||||
|
||||
class TensorSetItemByTensorsWithTensor(Cell):
|
||||
def __init__(self):
|
||||
super(TensorSetItemByTensorsWithTensor, self).__init__()
|
||||
|
@ -361,6 +456,23 @@ class TensorSetItemByTensorsWithTensor(Cell):
|
|||
return ret
|
||||
|
||||
|
||||
def test_setitem_by_tensors_with_tensor():
|
||||
net = TensorSetItemByTensorsWithTensor()
|
||||
index_0 = np.random.randint(6, size=(3, 4, 5))
|
||||
index_1 = np.random.randint(7, size=(4, 5))
|
||||
index_2 = np.random.randint(8, size=(5, 3, 4, 5))
|
||||
value = np.zeros((4, 5)).astype(np.float32)
|
||||
index_0_ms = Tensor(index_0, mstype.int32)
|
||||
index_1_ms = Tensor(index_1, mstype.int32)
|
||||
index_2_ms = Tensor(index_2, mstype.int32)
|
||||
value_ms = Tensor(value, mstype.float32)
|
||||
out = net(index_0_ms, index_1_ms, index_2_ms, value_ms)
|
||||
const = np.ones((6, 7, 8)).astype(np.float32)
|
||||
input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32)
|
||||
input_data[index_0, index_1, index_2] = value
|
||||
assert np.all(out.asnumpy() == (input_data + const))
|
||||
|
||||
|
||||
class TensorSetItemByTensorsWithTensorNumberError(Cell):
|
||||
def __init__(self):
|
||||
super(TensorSetItemByTensorsWithTensorNumberError, self).__init__()
|
||||
|
@ -373,6 +485,17 @@ class TensorSetItemByTensorsWithTensorNumberError(Cell):
|
|||
return ret
|
||||
|
||||
|
||||
def test_setitem_by_tensors_with_tensor_error():
|
||||
index_0 = Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32)
|
||||
index_1 = Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)
|
||||
index_2 = Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)
|
||||
index_3 = Tensor(np.random.randint(8, size=(1, 3, 4, 5)), mstype.int32)
|
||||
value = Tensor(np.zeros((2, 5)), mstype.float32)
|
||||
net = TensorSetItemByTensorsWithTensorNumberError()
|
||||
with pytest.raises(IndexError):
|
||||
net(index_0, index_1, index_2, index_3, value)
|
||||
|
||||
|
||||
class TensorSetItemByTensorsWithTupleOfNumber(Cell):
|
||||
def __init__(self, value):
|
||||
super(TensorSetItemByTensorsWithTupleOfNumber, self).__init__()
|
||||
|
@ -386,6 +509,22 @@ class TensorSetItemByTensorsWithTupleOfNumber(Cell):
|
|||
return ret
|
||||
|
||||
|
||||
def test_setitem_by_tensors_with_tuple_of_number():
|
||||
value = (0.0, 1.1, 2.2, 3.3, 4.4)
|
||||
net = TensorSetItemByTensorsWithTupleOfNumber(value)
|
||||
index_0 = np.random.randint(6, size=(3, 4, 5))
|
||||
index_1 = np.random.randint(7, size=(4, 5))
|
||||
index_2 = np.random.randint(8, size=(5, 3, 4, 5))
|
||||
index_0_ms = Tensor(index_0, mstype.int32)
|
||||
index_1_ms = Tensor(index_1, mstype.int32)
|
||||
index_2_ms = Tensor(index_2, mstype.int32)
|
||||
input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32)
|
||||
input_data[index_0, index_1, index_2] = value
|
||||
const = np.ones((6, 7, 8)).astype(np.float32)
|
||||
out = net(index_0_ms, index_1_ms, index_2_ms)
|
||||
assert np.all(out.asnumpy() == (input_data + const))
|
||||
|
||||
|
||||
class TensorSetItemByTensorsWithTupleOfTensor(Cell):
|
||||
def __init__(self):
|
||||
super(TensorSetItemByTensorsWithTupleOfTensor, self).__init__()
|
||||
|
@ -398,6 +537,27 @@ class TensorSetItemByTensorsWithTupleOfTensor(Cell):
|
|||
return ret
|
||||
|
||||
|
||||
def test_setitem_by_tensors_with_tuple_of_tensor():
|
||||
value_0 = np.zeros((4, 5))
|
||||
value_1 = np.ones((4, 5))
|
||||
value_2 = np.ones((4, 5)) * 2
|
||||
value_0_ms = Tensor(value_0, mstype.float32)
|
||||
value_1_ms = Tensor(value_1, mstype.float32)
|
||||
value_2_ms = Tensor(value_2, mstype.float32)
|
||||
net = TensorSetItemByTensorsWithTupleOfTensor()
|
||||
index_0 = np.random.randint(6, size=(3, 4, 5))
|
||||
index_1 = np.random.randint(7, size=(4, 5))
|
||||
index_2 = np.random.randint(8, size=(5, 3, 4, 5))
|
||||
index_0_ms = Tensor(index_0, mstype.int32)
|
||||
index_1_ms = Tensor(index_1, mstype.int32)
|
||||
index_2_ms = Tensor(index_2, mstype.int32)
|
||||
input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32)
|
||||
input_data[index_0, index_1, index_2] = (value_0, value_1, value_2)
|
||||
const = np.ones((6, 7, 8)).astype(np.float32)
|
||||
out = net(index_0_ms, index_1_ms, index_2_ms, value_0_ms, value_1_ms, value_2_ms)
|
||||
assert np.all(out.asnumpy() == (input_data + const))
|
||||
|
||||
|
||||
class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell):
|
||||
def __init__(self):
|
||||
super(TensorSetItemByTensorsWithTupleOfTensorNumberError, self).__init__()
|
||||
|
@ -410,17 +570,44 @@ class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell):
|
|||
return ret
|
||||
|
||||
|
||||
class TensorSetItemByMixedTensors(Cell):
|
||||
def __init__(self):
|
||||
super(TensorSetItemByMixedTensors, self).__init__()
|
||||
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
|
||||
self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
|
||||
self.value = 99.0
|
||||
def test_setitem_by_tensor_with_tuple_of_tensor_error():
|
||||
net = TensorSetItemByTensorsWithTupleOfTensorNumberError()
|
||||
index_0_ms = Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32)
|
||||
index_1_ms = Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)
|
||||
index_2_ms = Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)
|
||||
value_0 = np.zeros((4, 5))
|
||||
value_1 = np.ones((4, 5))
|
||||
value_0_ms = Tensor(value_0, mstype.float32)
|
||||
value_1_ms = Tensor(value_1, mstype.float32)
|
||||
with pytest.raises(ValueError):
|
||||
net(index_0_ms, index_1_ms, index_2_ms, value_0_ms, value_1_ms)
|
||||
|
||||
def construct(self, index_0, index_1):
|
||||
self.param[index_0, index_1, 0:6] = self.value
|
||||
ret = self.param + self.const
|
||||
return ret
|
||||
|
||||
def test_setitem_grad():
|
||||
class Net(Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.weight = Parameter(
|
||||
Tensor(np.ones([4, 4, 5]), dtype=mstype.float32), "b1", requires_grad=True)
|
||||
|
||||
def construct(self, a, b):
|
||||
a[1:3:1, ::] = b
|
||||
c = a + self.weight
|
||||
return c
|
||||
|
||||
class GradNet(Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
self.weights = ParameterTuple(net.trainable_params())
|
||||
|
||||
def construct(self, x, y, sens):
|
||||
return C.grad_by_list_with_sens(self.net, self.weights)(x, y, sens)
|
||||
net = GradNet(Net())
|
||||
x = Tensor(np.ones([4, 4, 5]).astype(np.float32), mstype.float32)
|
||||
y = Tensor(np.array([3]).astype(np.float32), mstype.float32)
|
||||
sens = Tensor(np.ones([4, 4, 5]).astype(np.float32), mstype.float32)
|
||||
net(x, y, sens)
|
||||
|
||||
|
||||
class TensorAssignWithSliceError1(Cell):
|
||||
|
@ -475,7 +662,6 @@ class TensorAssignWithSlice(Cell):
|
|||
|
||||
|
||||
def test_tensor_assign():
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
net = TensorAssignWithSlice()
|
||||
net2 = TensorAssignWithSlice2()
|
||||
net_e1 = TensorAssignWithSliceError1()
|
||||
|
@ -621,7 +807,7 @@ class TensorAssignWithTupleInteger(Cell):
|
|||
class TensorAssignWithBoolTensorIndex(Cell):
|
||||
def __init__(self):
|
||||
super(TensorAssignWithBoolTensorIndex, self).__init__()
|
||||
self.t = Tensor(np.arange(60).reshape([3, 4, 5]), dtype=mstype.float32)
|
||||
self.t = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||
self.u_scalar = 5
|
||||
|
||||
def construct(self, a, b, c, u_tensor):
|
||||
|
@ -643,8 +829,7 @@ class TensorAssignWithBoolTensorIndexError(Cell):
|
|||
class TensorAssignWithBoolTensorIndex2(Cell):
|
||||
def __init__(self):
|
||||
super(TensorAssignWithBoolTensorIndex2, self).__init__()
|
||||
self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float32)
|
||||
self.t = Tensor(np.arange(60).reshape([3, 4, 5]), dtype=mstype.float32)
|
||||
self.t = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
||||
self.u_scalar = 5
|
||||
|
||||
def construct(self, a, u_tensor):
|
||||
|
@ -666,7 +851,40 @@ class TensorAssignWithBoolTensorIndex2Error(Cell):
|
|||
return a
|
||||
|
||||
|
||||
def Xtest_tensor_assign_bool_index():
|
||||
def test_tensor_assign_bool_index_0():
|
||||
a = np.arange(60).reshape(3, 4, 5)
|
||||
b = a > 5
|
||||
c = a < 3
|
||||
Ta = Tensor(a, dtype=mstype.float32)
|
||||
Tb = Tensor(b)
|
||||
Tc = Tensor(c)
|
||||
u_tensor = Tensor([1], dtype=mstype.float32)
|
||||
net1 = TensorAssignWithBoolTensorIndex()
|
||||
out = net1(Ta, Tb, Tc, u_tensor)
|
||||
res = np.arange(60).reshape(3, 4, 5)
|
||||
res[c] = 5
|
||||
res[b] = 1
|
||||
res = res + np.ones([3, 4, 5])
|
||||
assert np.all(out.asnumpy() == res)
|
||||
|
||||
|
||||
def test_tensor_assign_bool_index_1():
|
||||
a = np.arange(60).reshape(3, 4, 5)
|
||||
Ta = Tensor(a, dtype=mstype.float32)
|
||||
u_tensor = Tensor([1], dtype=mstype.float32)
|
||||
net2 = TensorAssignWithBoolTensorIndex2()
|
||||
out = net2(Ta, u_tensor)
|
||||
res = np.arange(60).reshape(3, 4, 5)
|
||||
res[res > 8] = 1
|
||||
res[res >= 6] = 5
|
||||
res[res < 3] = 5
|
||||
res[res <= 5] = 1
|
||||
res[res == 5] = 5
|
||||
res = res + np.ones([3, 4, 5])
|
||||
assert np.all(out.asnumpy() == res)
|
||||
|
||||
|
||||
def test_tensor_assign_bool_index_exception():
|
||||
a = np.arange(60).reshape(3, 4, 5)
|
||||
b = a > 5
|
||||
c = a < 3
|
||||
|
@ -679,8 +897,6 @@ def Xtest_tensor_assign_bool_index():
|
|||
u_scalar = 5
|
||||
net1 = TensorAssignWithBoolTensorIndex()
|
||||
net2 = TensorAssignWithBoolTensorIndex2()
|
||||
net1(Ta, Tb, Tc, u_tensor)
|
||||
net1(Ta, Tb, Tc, u_tensor)
|
||||
with pytest.raises(ValueError):
|
||||
net1(Ta, Td, Tc, u_tensor)
|
||||
with pytest.raises(IndexError):
|
||||
|
@ -695,14 +911,14 @@ def Xtest_tensor_assign_bool_index():
|
|||
with pytest.raises(ValueError):
|
||||
net2(Ta, u_tensor_error)
|
||||
net3 = TensorAssignWithBoolTensorIndexError()
|
||||
with pytest.raises(AttributeError):
|
||||
with pytest.raises(IndexError):
|
||||
net3(Ta, Tb, Tc, u_tensor)
|
||||
with pytest.raises(AttributeError):
|
||||
with pytest.raises(IndexError):
|
||||
net3(Ta, Tb, Tc, u_scalar)
|
||||
net4 = TensorAssignWithBoolTensorIndex2Error()
|
||||
with pytest.raises(AttributeError):
|
||||
with pytest.raises(IndexError):
|
||||
net4(Ta, u_tensor)
|
||||
with pytest.raises(AttributeError):
|
||||
with pytest.raises(IndexError):
|
||||
net4(Ta, u_scalar)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue