add support for parameter setitem

support for tensor setitem

add support for tensor assign
huangdongrun 2020-06-12 16:08:54 +08:00
parent c55b81e94f
commit 79058d3509
7 changed files with 699 additions and 360 deletions
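
At a glance, the user-visible effect of the change is roughly the following (a sketch for orientation only, not part of the diff; eager-mode behaviour may differ from graph mode):

import numpy as np
from mindspore import Tensor

x = Tensor(np.zeros((2, 3), np.float32))
x[1] = 1.0                     # now dispatched through the registered '__setitem__' op
                               # and written back in place via the new assign_value
y = Tensor(np.ones((4, 4), np.float32))
x.assign_value(y)              # rebinds x's value (and shape) to y, as in the pybind docstring example
print(x.shape)                 # (4, 4)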


@@ -92,6 +92,10 @@ Tensor &Tensor::operator=(const Tensor &tensor) {
  }
  return *this;
}
+Tensor &Tensor::AssignValue(const Tensor &tensor) {
+  *this = tensor;
+  return *this;
+}

bool Tensor::operator==(const Tensor &tensor) const {
  return (MetaTensor::operator==(tensor) && data_ == tensor.data_);
@@ -470,6 +474,19 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
              >>> data.set_dtype(mindspore.int32)
              mindspore.int32
          )mydelimiter")
+        .def("assign_value", &Tensor::AssignValue, R"mydelimiter(
+             Assign another tensor value to this.
+
+             Args:
+                 value (:class:`mindspore.Tensor`): The value tensor.
+
+             Examples:
+                 >>> data = mindspore.Tensor(np.ones((1, 2), np.float32))
+                 >>> data2 = mindspore.Tensor(np.ones((2, 2), np.float32))
+                 >>> data.assign_value(data2)
+                 >>> data.shape
+                 (2, 2)
+         )mydelimiter")
        .def("__str__", &Tensor::ToString)
        .def("__repr__", &Tensor::ToStringRepr)
        .def(py::pickle(


@@ -173,6 +173,9 @@ class Tensor : public MetaTensor {
  // It is different from 'operator==' which just compare shape/type/address, it do real value comparison.
  bool ValueEqual(const Tensor &other) const;

+  // Assign value to this tensor.
+  Tensor &AssignValue(const Tensor &tensor);
+
  bool operator==(const Value &other) const override {
    if (other.isa<Tensor>()) {
      auto other_ = static_cast<const Tensor &>(other);


@@ -203,6 +203,8 @@ class Parameter:
        return self.default_input / other

    def __setitem__(self, index, value):
+        default_input = self.default_input
+        default_input[index] = value
        return self

    def set_parameter_data(self, data):


@@ -150,6 +150,8 @@ class Tensor(Tensor_):
        return out

    def __setitem__(self, index, value):
+        out = tensor_operator_registry.get('__setitem__')(self, index, value)
+        self.assign_value(out)
        return self

    def __gt__(self, other):
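
The two pieces fit together as follows: the registered '__setitem__' op is purely functional and returns a fresh tensor with the update applied, while assign_value (backed by the new C++ AssignValue) writes that result back into the original Python object. Below is a rough NumPy analogue of this two-step flow; the names are illustrative, not MindSpore APIs:

import numpy as np

def functional_setitem(data, index, value):
    # stand-in for the registered '__setitem__' op: never mutates its input
    out = data.copy()
    out[index] = value
    return out

class BoxedTensor:
    # toy wrapper mimicking Tensor.__setitem__ + assign_value
    def __init__(self, data):
        self.data = np.asarray(data)

    def assign_value(self, other):
        self.data = np.asarray(other)
        return self

    def __setitem__(self, index, value):
        out = functional_setitem(self.data, index, value)
        self.assign_value(out)   # write the new result back in place

t = BoxedTensor(np.zeros((2, 3), np.float32))
t[0, 1] = 5.0
assert t.data[0, 1] == 5.0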


@@ -26,7 +26,7 @@ hyper_map = base.HyperMap()
pack = P.Pack(axis=-1)


-def broadcast(broadcast_shape, x):
+def _broadcast(broadcast_shape, x):
    """Broadcast tensor to the required shape."""
    if F.shape(x) == broadcast_shape:
        return x
@@ -36,13 +36,13 @@ def broadcast(broadcast_shape, x):
    return x


-def transform_indexing_tensor(broadcast_shape, final_shape, new_shape, x):
+def _transform_indexing_tensor(broadcast_shape, final_shape, new_shape, x):
    """Transform indexing tensor to the required."""
-    x = broadcast(broadcast_shape, x)
-    return broadcast(final_shape, F.reshape(x, new_shape))
+    x = _broadcast(broadcast_shape, x)
+    return _broadcast(final_shape, F.reshape(x, new_shape))


-def generate_indices_from_tuple_of_tensor(data, tuple_index, op_name):
+def _generate_indices_from_tuple_of_tensor(data, tuple_index, op_name):
    """Generate an indices tensor from a tuple of tensor."""
    indices = None
    check_index_tensor_number = const_utils.check_number_of_index_tensor(F.shape(data), len(tuple_index), op_name)
@@ -52,26 +52,31 @@ def generate_indices_from_tuple_of_tensor(data, tuple_index, op_name):
        if check_dtypes:
            shape_tuple = hyper_map(F.shape, tuple_index)
            broadcast_shape = const_utils.generate_broadcast_shape(shape_tuple, op_name)
-            broadcast_tensors = hyper_map(F.partial(broadcast, broadcast_shape), tuple_index)
+            broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index)
            indices = pack(broadcast_tensors)
    return indices


-def generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
+def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
    """Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor."""
    indexes_types = hyper_map(F.typeof, tuple_index)
    int_positions = const_utils.get_pos_of_int_index(indexes_types)
-    for i in int_positions:
-        tuple_index = F.tuple_setitem(tuple_index, i, F.scalar_to_tensor(tuple_index[i], mstype.int32))
-    indexes_types = hyper_map(F.typeof, tuple_index)
+    tuple_index_new = ()
+    tuple_len = len(tuple_index)
+    for i in range(tuple_len):
+        if i in int_positions:
+            tuple_index_new = tuple_index_new + (F.scalar_to_tensor(tuple_index[i], mstype.int32),)
+        else:
+            tuple_index_new = tuple_index_new + (tuple_index[i],)
+    indexes_types = hyper_map(F.typeof, tuple_index_new)
    tensor_positions, slice_positions, ellipsis_position = \
        const_utils.separate_mixed_tensors_index(indexes_types, op_name)
    tensor_indexes = []
    slice_indexes = []
    for i in tensor_positions:
-        tensor_indexes.append(tuple_index[i])
+        tensor_indexes.append(tuple_index_new[i])
    for j in slice_positions:
-        slice_indexes.append(tuple_index[j])
+        slice_indexes.append(tuple_index_new[j])
    data_shape = F.shape(data)
    tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
    tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes)
@@ -85,14 +90,14 @@ def generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
    slice_number = 0
    final_index_tensors = []
-    tuple_index_size = len(tuple_index)
+    tuple_index_size = len(tuple_index_new)
    index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info)
    for i in range(tuple_index_size):
        if i in tensor_positions:
-            transform_tensor = transform_indexing_tensor(broadcast_shape,
-                                                         final_shape,
-                                                         index_tensor_new_shape,
-                                                         tuple_index[i])
+            transform_tensor = _transform_indexing_tensor(broadcast_shape,
+                                                          final_shape,
+                                                          index_tensor_new_shape,
+                                                          tuple_index_new[i])
            final_index_tensors.append(transform_tensor)
        if i in slice_positions:
            slice_tensor = const_utils.convert_slice_to_tensor(slice_number,
@@ -114,7 +119,7 @@ def generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
    return indices


-def generate_updates_from_scalar(data, indices, value, op_type):
+def _generate_updates_from_scalar(data, indices, value, op_type):
    """Generate an updates tensor from a scalar."""
    data_shape = F.shape(data)
    indices_shape = F.shape(indices)
@@ -122,7 +127,7 @@ def generate_updates_from_scalar(data, indices, value, op_type):
    return const_utils.convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type)


-def generate_updates_from_tuple(data, index, value, op_type):
+def _generate_updates_from_tuple(data, index, value, op_type):
    """Generate an updates tensor from a tuple."""
    value_types = hyper_map(F.typeof, value)
    data_dtype = F.dtype(data)
@@ -132,14 +137,14 @@ def generate_updates_from_tuple(data, index, value, op_type):
        shapes_same = const_utils.check_shapes_same(value_shapes, const_utils.TENSOR_SETITEM)
        if shapes_same:
            value = F.pack(value)
-        return generate_updates_from_tensor(data, index, value, op_type)
+        return _generate_updates_from_tensor(data, index, value, op_type)
    data_shape = F.shape(data)
    index_shape = F.shape(index)
    return const_utils.convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type)


-def generate_updates_from_tensor(data, index, value, op_type):
+def _generate_updates_from_tensor(data, index, value, op_type):
    """Generate an updates tensor from a tensor."""
    data_shape = F.shape(data)
    index_shape = F.shape(index)
@@ -152,45 +157,47 @@ def generate_updates_from_tensor(data, index, value, op_type):
    updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type)
    need_broadcast = const_utils.check_two_shapes_need_broadcast(updates_shape, value_shape)
    if need_broadcast:
-        return broadcast(updates_shape, value)
+        return _broadcast(updates_shape, value)
    return value


-def tensor_getitem(self, index):
+def _tensor_getitem(self, index):
    """Handle tensor getitem"""
    if isinstance(index, Tensor):
        return tensor_index_by_tensor(self, index)
    if isinstance(index, tuple):
        return tensor_index_by_tuple(self, index)
    if isinstance(index, int):
-        return tensor_index_by_integer(self, index)
+        return _tensor_index_by_integer(self, index)
    if isinstance(index, slice):
        return tensor_index_by_slice(self, index)
    if isinstance(index, bool):
-        return tensor_index_by_bool(self, index)
+        return _tensor_index_by_bool(self, index)
+    if index is None:
+        return F.expand_dims(self, 0)
    if index is ...:
        return self
    raise IndexError(f"Only support integers, slices(`:`), ellipsis(`...`), None, bool and tensor with int32, "
                     f"got {index} with type {type(index)}.")


-tensor_operator_registry.register("__getitem__", tensor_getitem)
+tensor_operator_registry.register("__getitem__", _tensor_getitem)


-def tensor_getitem_by_tuple_of_tensor(data, tuple_index):
+def _tensor_getitem_by_tuple_of_tensor(data, tuple_index):
    """Tensor getitem by a tuple of tensor."""
-    indices = generate_indices_from_tuple_of_tensor(data,
-                                                    tuple_index,
-                                                    const_utils.TENSOR_GETITEM)
+    indices = _generate_indices_from_tuple_of_tensor(data,
+                                                     tuple_index,
+                                                     const_utils.TENSOR_GETITEM)
    result = F.gather_nd(data, indices)
    return result


-def tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index):
+def _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index):
    """Tensor getitem by a tuple of mixed tensor."""
-    indices = generate_indices_from_tuple_of_mixed_tensors(data,
-                                                           tuple_index,
-                                                           const_utils.TENSOR_GETITEM)
+    indices = _generate_indices_from_tuple_of_mixed_tensors(data,
+                                                            tuple_index,
+                                                            const_utils.TENSOR_GETITEM)
    result = F.gather_nd(data, indices)
    return result
@@ -204,7 +211,7 @@ def tensor_index_by_slice(data, slice_index):
    return F.strided_slice(data, begin_strides, end_strides, step_strides)


-def tensor_index_by_integer(data, number):
+def _tensor_index_by_integer(data, number):
    """Tensor getitem by a single integer number"""
    shape = F.shape(data)
    if not shape:
@@ -214,7 +221,7 @@ def tensor_index_by_integer(data, number):
    return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)


-def tensor_index_by_bool(data, bool_value):
+def _tensor_index_by_bool(data, bool_value):
    """Tensor getitem by a single bool value"""
    if bool_value:
        return F.expand_dims(data, 0)
@@ -225,9 +232,9 @@ def tensor_index_by_number(data, number):
    """Tensor getitem by a Number which may be integer/float/bool value"""
    number_type = const_utils.check_number_index_type(number)
    if number_type == const_utils.BOOL_:
-        return tensor_index_by_bool(data, number)
+        return _tensor_index_by_bool(data, number)
    if number_type == const_utils.INT_:
-        return tensor_index_by_integer(data, number)
+        return _tensor_index_by_integer(data, number)
    return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool.")
@@ -241,7 +248,7 @@ def tensor_index_by_tensor(data, tensor_index):
                                         "the index tensor data type only support mstype.int32.")


-def tensor_index_by_tuple_slice(data, t):
+def _tensor_index_by_tuple_slice(data, t):
    """Tensor getitem by a tuple of slice"""
    shape = F.shape(data)
    if len(t) > len(shape):
@@ -257,7 +264,303 @@ def tensor_index_by_tuple(data, tuple_index):
    indexes_types = hyper_map(F.typeof, tuple_index)
    index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM)
    if index_elements_type == const_utils.NO_TENSOR:
-        return tensor_index_by_tuple_slice(data, tuple_index)
+        return _tensor_index_by_tuple_slice(data, tuple_index)
    if index_elements_type == const_utils.ALL_TENSOR:
-        return tensor_getitem_by_tuple_of_tensor(data, tuple_index)
-    return tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index)
+        return _tensor_getitem_by_tuple_of_tensor(data, tuple_index)
+    return _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index)
def _tensor_setitem(self, index, value):
"""Handle tensor getitem"""
if isinstance(index, Tensor):
if isinstance(value, (int, float, bool)):
return tensor_setitem_by_tensor_with_number(self, index, value)
if isinstance(value, Tensor):
return tensor_setitem_by_tensor_with_tensor(self, index, value)
if isinstance(value, tuple):
return tensor_setitem_by_tensor_with_tuple(self, index, value)
if isinstance(index, tuple):
if isinstance(value, (int, float, bool)):
return tensor_setitem_by_tuple_with_number(self, index, value)
if isinstance(value, Tensor):
return tensor_setitem_by_tuple_with_tensor(self, index, value)
if isinstance(value, tuple):
return tensor_setitem_by_tuple_with_tuple(self, index, value)
if isinstance(index, int):
if isinstance(value, (int, float, bool)):
return tensor_setitem_by_number_with_number(self, index, value)
if isinstance(value, Tensor):
return tensor_setitem_by_number_with_tensor(self, index, value)
if isinstance(index, slice):
if isinstance(value, (int, float, bool)):
return tensor_setitem_by_slice_with_number(self, index, value)
if isinstance(value, Tensor):
return tensor_setitem_by_slice_with_tensor(self, index, value)
if isinstance(index, bool):
return _tensor_index_by_bool(self, index)
if index is ...:
if isinstance(value, (int, float, bool)):
return tensor_setitem_by_ellipsis_with_number(self, index, value)
if isinstance(value, Tensor):
return tensor_setitem_by_ellipsis_with_tensor(self, index, value)
raise IndexError("Tensor setitem index only support integers, slices(`:`), ellipsis(`...`), None, bool\
and tensor with int32, got {} with type{}".format(index, type(index)))
tensor_operator_registry.register("__setitem__", _tensor_setitem)
def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
"""Set a tensor item by a int tensor with a tensor."""
updates = _generate_updates_from_tensor(data, index, value,
const_utils.SET_ITEM_BY_ONE_TENSOR)
index = F.expand_dims(index, -1)
return P.TensorScatterUpdate()(data, index, updates)
def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
"""Set a tensor item by a bool tensor with a tensor."""
index_shape = F.shape(index)
data_shape = F.shape(data)
data_shape = const_utils.check_equal(data_shape, index_shape,
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
size = F.size(value)
size = const_utils.check_equal(1, size,
"When assign value is a tensor, its size should be {}, but current size is {}.")
dtype = F.dtype(data)
u_cast = F.cast(value, dtype)
one_data = F.ones_like(data)
u = F.tensor_mul(one_data, u_cast)
result = F.select(index, u, data)
return result
def tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
"""setitem by tensor index(dtype is int or bool) with tensor as value"""
index_dtype = F.dtype(index)
tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype)
if tensor_dtype == const_utils.INT_:
return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor)
return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor)
def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value):
"""Set a tensor item by a bool tensor with a scalar."""
index_shape = F.shape(index)
shape = F.shape(data)
shape = const_utils.check_equal(
shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
dtype = F.dtype(data)
u = F.fill(dtype, shape, value)
return F.select(index, u, data)
def _tensor_setitem_by_int_tensor_with_scalar(data, index, value):
"""Set a tensor item by a int tensor with a scalar."""
updates = _generate_updates_from_scalar(data, index, value,
const_utils.SET_ITEM_BY_ONE_TENSOR)
index = F.expand_dims(index, -1)
return P.TensorScatterUpdate()(data, index, updates)
def tensor_setitem_by_tensor_with_number(data, index, value):
index_dtype = F.dtype(index)
tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype)
if tensor_dtype == const_utils.BOOL_:
return _tensor_setitem_by_bool_tensor_with_scalar(data, index, value)
if tensor_dtype == const_utils.INT_:
return _tensor_setitem_by_int_tensor_with_scalar(data, index, value)
return const_utils.raise_index_error("For tensor setitem, indexing tensor dtype only supports bool/int")
def tensor_setitem_by_tensor_with_tuple(data, index, value):
"""Assigns the tensor by tensor with tuple value."""
index_dtype = F.dtype(index)
check_dtype = const_utils.check_index_tensor_dtype(index_dtype, const_utils.TENSOR_SETITEM)
result = None
if check_dtype:
result = _tensor_setitem_by_tensor_with_tuple(data, index, value)
return result
def _tensor_indices_number(data, data_shape, index, indices, value):
"""Assigns a scalar value to the tensor."""
data_size = F.size(data)
data_dtype = F.dtype(data)
indices_size = F.size(indices)
indices_size = const_utils.check_indices(indices_size, index)
update = F.fill(mstype.int32, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition = F.reshape(condition_1d, data_shape)
condition = F.cast(condition, mstype.bool_)
value_fill = F.fill(data_dtype, (indices_size,), value)
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
u = F.reshape(value_1d, data_shape)
return F.select(condition, u, data)
def _tensor_setitem_by_tensor_with_tuple(data, index, value):
"""Set a tensor item by a tensor with a tuple."""
updates = _generate_updates_from_tuple(data, index, value,
const_utils.SET_ITEM_BY_ONE_TENSOR)
index = F.expand_dims(index, -1)
result = P.TensorScatterUpdate()(data, index, updates)
return result
def tensor_setitem_by_slice_with_number(data, input_slice, value):
"""Givens a scalar assign to tensor by slice"""
check_result = const_utils.check_tensor_setitem_index(input_slice)
result = None
if check_result:
data_shape = F.shape(data)
indices = const_utils.slice2indices(input_slice, data_shape)
is_tuple_int = const_utils.tuple_element_is_int(input_slice)
if is_tuple_int:
indices = const_utils.integer_to_indices(input_slice, data_shape)
result = _tensor_indices_number(data, data_shape, input_slice, indices, value)
return result
def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
"""Assigns the tensor by tuple with number value."""
indexes_types = hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
if index_elements_type == const_utils.NO_TENSOR:
return tensor_setitem_by_slice_with_number(data, tuple_index, value)
if index_elements_type == const_utils.ALL_TENSOR:
indices = _generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_SETITEM)
else:
indices = _generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_SETITEM)
updates = _generate_updates_from_scalar(data,
indices,
value,
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
return P.TensorScatterUpdate()(data, indices, updates)
def _tensor_indices_tensor(data, data_shape, index, indices, value):
"""Assigns a tensor value to the tensor."""
data_size = F.size(data)
data_dtype = F.dtype(data)
indices_size = F.size(indices)
indices_size = const_utils.check_indices(indices_size, index)
update = F.fill(mstype.int32, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition = F.reshape(condition_1d, data_shape)
condition = F.cast(condition, mstype.bool_)
value_fill = None
value_size = F.size(value)
value_size = const_utils.check_indices_value_size(indices_size, value_size)
if value_size == 1:
value_fill = F.fill(data_dtype, (indices_size,), 1)
value = F.cast(value, data_dtype)
value_fill = F.tensor_mul(value_fill, value)
elif value_size > 1:
value_fill = F.reshape(value, (indices_size,))
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
u = F.reshape(value_1d, data_shape)
return F.select(condition, u, data)
def tensor_setitem_by_slice_with_tensor(data, input_slice, value):
"""Assigns a tensor value to the tensor by slice."""
result = None
check_result = const_utils.check_tensor_setitem_index(input_slice)
if check_result:
data_shape = F.shape(data)
indices = const_utils.slice2indices(input_slice, data_shape)
is_tuple_int = const_utils.tuple_element_is_int(input_slice)
if is_tuple_int:
indices = const_utils.integer_to_indices(input_slice, data_shape)
result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value)
return result
def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
"""Assigns the tensor by tuple with tensor value."""
indexes_types = hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
if index_elements_type == const_utils.NO_TENSOR:
return tensor_setitem_by_slice_with_tensor(data, tuple_index, value)
if index_elements_type == const_utils.ALL_TENSOR:
indices = _generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_SETITEM)
else:
indices = _generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_SETITEM)
updates = _generate_updates_from_tensor(data,
indices,
value,
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
return P.TensorScatterUpdate()(data, indices, updates)
def tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
"""Assigns the tensor by tuple with tuple of value."""
indexes_types = hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
if index_elements_type == const_utils.ALL_TENSOR:
indices = _generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_SETITEM)
else:
indices = _generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_SETITEM)
updates = _generate_updates_from_tuple(data,
indices,
value,
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
return P.TensorScatterUpdate()(data, indices, updates)
def tensor_setitem_by_number_with_number(data, index, value):
"""Assigns the tensor by number with number value."""
data_shape = F.shape(data)
indices = const_utils.integer_to_indices(index, data_shape)
return _tensor_indices_number(data, data_shape, index, indices, value)
def tensor_setitem_by_number_with_tensor(data, index, value):
"""Assigns the tensor by number with tensor value."""
data_shape = F.shape(data)
indices = const_utils.integer_to_indices(index, data_shape)
return _tensor_indices_tensor(data, data_shape, index, indices, value)
def tensor_setitem_by_ellipsis_with_number(data, index, value):
"""Assigns the tensor by ellipsis with number value."""
data_shape = F.shape(data)
data_dtype = F.dtype(data)
return F.fill(data_dtype, data_shape, value)
def tensor_setitem_by_ellipsis_with_tensor(data, index, value):
"""Assigns the tensor by ellipsis with tensor value."""
result = None
data_shape = F.shape(data)
data_dtype = F.dtype(data)
data_size = F.size(data)
value_shape = F.shape(value)
value_size = F.size(value)
check_result = const_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size)
if check_result:
if data_size == value_size:
result = F.reshape(value, data_shape)
result = F.cast(result, data_dtype)
elif value_size == 1:
param1 = F.fill(data_dtype, data_shape, 1)
param2 = F.cast(value, data_dtype)
result = F.tensor_mul(param1, param2)
return result
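
For reference, a rough NumPy analogue of the integer-tensor branch above (_generate_updates_from_tensor followed by expand_dims and TensorScatterUpdate), assuming the index tensor selects along the first axis; illustrative only:

import numpy as np

def setitem_by_int_tensor_with_tensor(data, index, value):
    # updates must have shape index.shape + data.shape[1:]; broadcast the value if needed
    updates = np.broadcast_to(value, index.shape + data.shape[1:])
    out = data.copy()                # the graph op is functional, so work on a copy
    out[index] = updates             # scatter along axis 0, like TensorScatterUpdate
    return out

data = np.arange(6 * 7 * 8, dtype=np.float32).reshape(6, 7, 8)
index = np.random.randint(6, size=(5, 4))
out = setitem_by_int_tensor_with_tensor(data, index, np.zeros((4, 7, 8), np.float32))
assert np.all(out[index] == 0)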


@@ -16,10 +16,8 @@
"""Implementation for setitem."""

from . import _compile_utils as compile_utils
-from . import _constexpr_utils as const_utils
from ... import functional as F
from ...composite import base
-from ....common import dtype as mstype

setitem = base.MultitypeFuncGraph('setitem')
@@ -139,11 +137,7 @@ def _tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
    Outputs:
        Tensor, element type and shape is same as data.
    """
-    index_dtype = F.dtype(index)
-    tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype)
-    if tensor_dtype == const_utils.INT_:
-        return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor)
-    return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor)
+    return compile_utils.tensor_setitem_by_tensor_with_tensor(data, index, value_tensor)
@setitem.register("Tensor", "Tensor", "Number") @setitem.register("Tensor", "Tensor", "Number")
@ -166,11 +160,7 @@ def _tensor_setitem_by_tensor_with_number(data, index, value):
Outputs: Outputs:
Tensor, element type and shape is same as data. Tensor, element type and shape is same as data.
""" """
index_dtype = F.dtype(index) return compile_utils.tensor_setitem_by_tensor_with_number(data, index, value)
tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype)
if tensor_dtype == const_utils.BOOL_:
return _tensor_setitem_by_bool_tensor_with_scalar(data, index, value)
return _tensor_setitem_by_int_tensor_with_scalar(data, index, value)
@setitem.register("Tensor", "Tuple", "Number") @setitem.register("Tensor", "Tuple", "Number")
@ -191,24 +181,7 @@ def _tensor_setitem_by_tuple_with_number(data, tuple_index, value):
Outputs: Outputs:
Tensor, element type and shape is same as data. Tensor, element type and shape is same as data.
""" """
indexes_types = compile_utils.hyper_map(F.typeof, tuple_index) return compile_utils.tensor_setitem_by_tuple_with_number(data, tuple_index, value)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
if index_elements_type == const_utils.NO_TENSOR:
return _tensor_assgin_number(data, tuple_index, value)
if index_elements_type == const_utils.ALL_TENSOR:
indices = compile_utils.generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_SETITEM)
else:
indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_SETITEM)
updates = compile_utils.generate_updates_from_scalar(data,
indices,
value,
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
return F.scatter_nd_update(data, indices, updates)
@setitem.register("Tensor", "Tuple", "Tensor") @setitem.register("Tensor", "Tuple", "Tensor")
@ -229,24 +202,7 @@ def _tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
Outputs: Outputs:
Tensor, element type and shape is same as data. Tensor, element type and shape is same as data.
""" """
indexes_types = compile_utils.hyper_map(F.typeof, tuple_index) return compile_utils.tensor_setitem_by_tuple_with_tensor(data, tuple_index, value)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
if index_elements_type == const_utils.NO_TENSOR:
return _tensor_assgin_tensor(data, tuple_index, value)
if index_elements_type == const_utils.ALL_TENSOR:
indices = compile_utils.generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_SETITEM)
else:
indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_SETITEM)
updates = compile_utils.generate_updates_from_tensor(data,
indices,
value,
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
return F.scatter_nd_update(data, indices, updates)
@setitem.register("Tensor", "Tuple", "Tuple") @setitem.register("Tensor", "Tuple", "Tuple")
@ -268,22 +224,7 @@ def _tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
Outputs: Outputs:
Tensor, element type and shape is same as data. Tensor, element type and shape is same as data.
""" """
indexes_types = compile_utils.hyper_map(F.typeof, tuple_index) return compile_utils.tensor_setitem_by_tuple_with_tuple(data, tuple_index, value)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
if index_elements_type == const_utils.ALL_TENSOR:
indices = compile_utils.generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_SETITEM)
else:
indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_SETITEM)
updates = compile_utils.generate_updates_from_tuple(data,
indices,
value,
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
return F.scatter_nd_update(data, indices, updates)
@setitem.register("Tensor", "Tensor", "Tuple") @setitem.register("Tensor", "Tensor", "Tuple")
@ -299,12 +240,7 @@ def _tensor_setitem_by_tensor_v2(data, index, value):
Outputs: Outputs:
Tensor, element type and shape is same as data. Tensor, element type and shape is same as data.
""" """
index_dtype = F.dtype(index) return compile_utils.tensor_setitem_by_tensor_with_tuple(data, index, value)
check_dtype = const_utils.check_index_tensor_dtype(index_dtype, const_utils.TENSOR_SETITEM)
result = None
if check_dtype:
result = _tensor_setitem_by_tensor_with_tuple(data, index, value)
return result
@setitem.register("Tensor", "Slice", "Tensor") @setitem.register("Tensor", "Slice", "Tensor")
@ -326,7 +262,7 @@ def _tensor_setitem_with_slice_v3(data, input_slice, value):
Outputs: Outputs:
Tensor, element type and shape is same as data. Tensor, element type and shape is same as data.
""" """
return _tensor_assgin_tensor(data, input_slice, value) return compile_utils.tensor_setitem_by_slice_with_tensor(data, input_slice, value)
@setitem.register("Tensor", "Slice", "Number") @setitem.register("Tensor", "Slice", "Number")
@ -348,168 +284,28 @@ def _tensor_setitem_with_slice_v1(data, input_slice, value):
Outputs: Outputs:
Tensor, element type and shape is same as data. Tensor, element type and shape is same as data.
""" """
return _tensor_assgin_number(data, input_slice, value) return compile_utils.tensor_setitem_by_slice_with_number(data, input_slice, value)
def _tensor_assgin_number(data, input_slice, value):
"""Givens a scalar assign to tensor by slice"""
check_result = const_utils.check_tensor_setitem_index(input_slice)
result = None
if check_result:
data_shape = F.shape(data)
indices = const_utils.slice2indices(input_slice, data_shape)
is_tuple_int = const_utils.tuple_element_is_int(input_slice)
if is_tuple_int:
indices = const_utils.integer_to_indices(input_slice, data_shape)
result = _tensor_indices_number(data, data_shape, input_slice, indices, value)
return result
@setitem.register("Tensor", "Number", "Number") @setitem.register("Tensor", "Number", "Number")
def _tensor_setitem_with_int_v1(data, index, value): def _tensor_setitem_with_int_v1(data, index, value):
"""Syntax: A[1] = 3""" """Syntax: A[1] = 3"""
data_shape = F.shape(data) return compile_utils.tensor_setitem_by_number_with_number(data, index, value)
indices = const_utils.integer_to_indices(index, data_shape)
return _tensor_indices_number(data, data_shape, index, indices, value)
@setitem.register("Tensor", "Number", "Tensor") @setitem.register("Tensor", "Number", "Tensor")
def _tensor_setitem_with_int_v2(data, index, value): def _tensor_setitem_with_int_v2(data, index, value):
"""Syntax: A[1] = Tensor""" """Syntax: A[1] = Tensor"""
data_shape = F.shape(data) return compile_utils.tensor_setitem_by_number_with_tensor(data, index, value)
indices = const_utils.integer_to_indices(index, data_shape)
return _tensor_indices_tensor(data, data_shape, index, indices, value)
@setitem.register("Tensor", "Ellipsis", "Number") @setitem.register("Tensor", "Ellipsis", "Number")
def _tensor_setitem_with_ellipsis_v1(data, index, value): def _tensor_setitem_with_ellipsis_v1(data, index, value):
"""Syntax: A[...] = number.""" """Syntax: A[...] = number."""
data_shape = F.shape(data) return compile_utils.tensor_setitem_by_ellipsis_with_number(data, index, value)
data_dtype = F.dtype(data)
return F.fill(data_dtype, data_shape, value)
@setitem.register("Tensor", "Ellipsis", "Tensor") @setitem.register("Tensor", "Ellipsis", "Tensor")
def _tensor_setitem_with_ellipsis_v2(data, index, value): def _tensor_setitem_with_ellipsis_v2(data, index, value):
"""Syntax: A[...] = Tensor.""" """Syntax: A[...] = Tensor."""
result = None return compile_utils.tensor_setitem_by_ellipsis_with_tensor(data, index, value)
data_shape = F.shape(data)
data_dtype = F.dtype(data)
data_size = F.size(data)
value_shape = F.shape(value)
value_size = F.size(value)
check_result = const_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size)
if check_result:
if data_size == value_size:
result = F.reshape(value, data_shape)
result = F.cast(result, data_dtype)
elif value_size == 1:
param1 = F.fill(data_dtype, data_shape, 1)
param2 = F.cast(value, data_dtype)
result = F.tensor_mul(param1, param2)
return result
def _tensor_assgin_tensor(data, input_slice, value):
"""Assigns a tensor value to the tensor by slice."""
result = None
check_result = const_utils.check_tensor_setitem_index(input_slice)
if check_result:
data_shape = F.shape(data)
indices = const_utils.slice2indices(input_slice, data_shape)
is_tuple_int = const_utils.tuple_element_is_int(input_slice)
if is_tuple_int:
indices = const_utils.integer_to_indices(input_slice, data_shape)
result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value)
return result
def _tensor_indices_tensor(data, data_shape, index, indices, value):
"""Assigns a tensor value to the tensor."""
data_size = F.size(data)
data_dtype = F.dtype(data)
indices_size = F.size(indices)
indices_size = const_utils.check_indices(indices_size, index)
update = F.fill(mstype.int32, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition = F.reshape(condition_1d, data_shape)
condition = F.cast(condition, mstype.bool_)
value_fill = None
value_size = F.size(value)
value_size = const_utils.check_indices_value_size(indices_size, value_size)
if value_size == 1:
value_fill = F.fill(data_dtype, (indices_size,), 1)
value = F.cast(value, data_dtype)
value_fill = F.tensor_mul(value_fill, value)
elif value_size > 1:
value_fill = F.reshape(value, (indices_size,))
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
u = F.reshape(value_1d, data_shape)
return F.select(condition, u, data)
def _tensor_indices_number(data, data_shape, index, indices, value):
"""Assigns a scalar value to the tensor."""
data_size = F.size(data)
data_dtype = F.dtype(data)
indices_size = F.size(indices)
indices_size = const_utils.check_indices(indices_size, index)
update = F.fill(mstype.int32, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition = F.reshape(condition_1d, data_shape)
condition = F.cast(condition, mstype.bool_)
value_fill = F.fill(data_dtype, (indices_size,), value)
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
u = F.reshape(value_1d, data_shape)
return F.select(condition, u, data)
def _tensor_setitem_by_tensor_with_tuple(data, index, value):
"""Set a tensor item by a tensor with a tuple."""
updates = compile_utils.generate_updates_from_tuple(data, index, value,
const_utils.SET_ITEM_BY_ONE_TENSOR)
result = F.scatter_update(data, index, updates)
return result
def _tensor_setitem_by_int_tensor_with_scalar(data, index, value):
"""Set a tensor item by a int tensor with a scalar."""
updates = compile_utils.generate_updates_from_scalar(data, index, value,
const_utils.SET_ITEM_BY_ONE_TENSOR)
return F.scatter_update(data, index, updates)
def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value):
"""Set a tensor item by a bool tensor with a scalar."""
index_shape = F.shape(index)
shape = F.shape(data)
shape = const_utils.check_equal(
shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
dtype = F.dtype(data)
u = F.fill(dtype, shape, value)
return F.select(index, u, data)
def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
"""Set a tensor item by a int tensor with a tensor."""
updates = compile_utils.generate_updates_from_tensor(data, index, value,
const_utils.SET_ITEM_BY_ONE_TENSOR)
return F.scatter_update(data, index, updates)
def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
"""Set a tensor item by a bool tensor with a tensor."""
index_shape = F.shape(index)
data_shape = F.shape(data)
data_shape = const_utils.check_equal(data_shape, index_shape,
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
size = F.size(value)
size = const_utils.check_equal(1, size,
"When assign value is a tensor, its size should be {}, but current size is {}.")
dtype = F.dtype(data)
u_cast = F.cast(value, dtype)
one_data = F.ones_like(data)
u = F.tensor_mul(one_data, u_cast)
result = F.select(index, u, data)
return result


@@ -20,10 +20,14 @@ from mindspore import Tensor, Parameter
from mindspore import context
from mindspore import dtype as mstype
from mindspore.nn import Cell
+from mindspore.common.parameter import ParameterTuple
+from mindspore.ops import composite as C


def setup_module():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")


class NetWorkSlicePositive(Cell):
    def __init__(self):
        super(NetWorkSlicePositive, self).__init__()
@@ -139,7 +143,7 @@ class TensorGetItemByThreeTensors(Cell):
        return ret0, ret1, ret2


-def Xtest_getitem_by_tensors():
+def test_getitem_by_tensors():
    net = TensorGetItemByThreeTensors()
    input_x = np.arange(6*8*10).reshape(6, 8, 10).astype(np.int32)
    index_0 = np.random.randint(6, size=(3, 4, 5)).astype(np.int32)
@@ -155,119 +159,140 @@ def Xtest_getitem_by_tensors():
    assert np.all(output2.asnumpy() == input_x[index_0, index_1, index_2] + np.ones([5, 3, 4, 5]))


-class TensorGetItemByMixedTensors_0(Cell):
-    def __init__(self):
-        super(TensorGetItemByMixedTensors_0, self).__init__()
-        self.const = Tensor(np.ones((3, 4, 5, 3, 6, 5), np.float32))
-
-    def construct(self, tensor, index_0, index_1):
-        ret = tensor[index_0, index_1, 0:3, ..., 0:5, 3] + self.const
-        return ret
-
-
-class TensorGetItemByMixedTensors_1(Cell):
-    def __init__(self):
-        super(TensorGetItemByMixedTensors_1, self).__init__()
-        self.const = Tensor(np.ones((3, 4, 5, 3, 5, 5), np.float32))
-
-    def construct(self, tensor, index_0, index_1):
-        ret = tensor[0:3, index_0, ..., index_1, 3, 0:5] + self.const
-        return ret
-
-
-class TensorGetItemByMixedTensors_2(Cell):
-    def __init__(self):
-        super(TensorGetItemByMixedTensors_2, self).__init__()
-        self.const = Tensor(np.ones((3, 4, 5, 6, 7), np.float32))
-
-    def construct(self, tensor, index_0, index_1):
-        ret = tensor[0, index_0, index_1, ..., 3] + self.const
-        return ret
-
-
-class TensorGetItemByMixedTensors_3(Cell):
-    def __init__(self):
-        super(TensorGetItemByMixedTensors_3, self).__init__()
-        self.const = Tensor(np.ones((3, 4, 5, 3, 4, 3, 5), np.float32))
-
-    def construct(self, tensor, index_0, index_1):
-        ret = tensor[..., index_0, 0:3, index_1, 0:5] + self.const
-        return ret
-
-
-class TensorGetItemByMixedTensors_4(Cell):
-    def __init__(self):
-        super(TensorGetItemByMixedTensors_4, self).__init__()
-        self.const = Tensor(np.ones((2, 2, 3, 4, 5, 3, 9), np.float32))
-
-    def construct(self, tensor, index_0, index_1, index_2):
-        ret = tensor[0:2, index_0, index_1, 2, index_2, 0:3, ...] + self.const
-        return ret
-
-
-class TensorGetItemByMixedTensors_5(Cell):
-    def __init__(self):
-        super(TensorGetItemByMixedTensors_5, self).__init__()
-        self.const = Tensor(np.ones((2, 3, 4, 5, 2, 6), np.float32))
-
-    def construct(self, tensor, index_0, index_1, index_2):
-        ret = tensor[0:2, index_0, index_1, ..., index_2, 2] + self.const
-        return ret
-
-
-class TensorGetItemByMixedTensors_6(Cell):
-    def __init__(self):
-        super(TensorGetItemByMixedTensors_6, self).__init__()
-        self.const = Tensor(np.ones((3, 4, 2, 3, 4, 5), np.float32))
-
-    def construct(self, tensor, index_0, index_1, index_2):
-        ret = tensor[..., index_0, index_1, index_2, 3] + self.const
-        return ret
+class TensorGetItemByMixedTensorsBasicCase(Cell):
+    def __init__(self, c0, c1, c2, c3, c4, c5):
+        super(TensorGetItemByMixedTensorsBasicCase, self).__init__()
+        self.const0 = Tensor(c0)
+        self.const1 = Tensor(c1)
+        self.const2 = Tensor(c2)
+        self.const3 = Tensor(c3)
+        self.const4 = Tensor(c4)
+        self.const5 = Tensor(c5)
+
+    def construct(self, tensor, index_0, index_1):
+        ret0 = tensor[index_0, index_1, 0:3] + self.const0
+        ret1 = tensor[0:3, index_0, ...] + self.const1
+        ret2 = tensor[0, index_0, index_1] + self.const2
+        ret3 = tensor[..., index_0, 0:3] + self.const3
+        ret4 = tensor[0:2, index_0, index_1] + self.const4
+        ret5 = tensor[..., index_0, index_1] + self.const5
+        return ret0, ret1, ret2, ret3, ret4, ret5
+
+
+def test_getitem_by_mixed_tensors():
+    const0 = np.ones((3, 4, 5, 3), np.float32)
+    const1 = np.ones((3, 3, 4, 5, 5), np.float32)
+    const2 = np.ones((3, 4, 5), np.float32)
+    const3 = np.ones((3, 3, 4, 5, 3), np.float32)
+    const4 = np.ones((2, 3, 4, 5), np.float32)
+    const5 = np.ones((3, 3, 4, 5), np.float32)
+    net = TensorGetItemByMixedTensorsBasicCase(const0, const1, const2, const3, const4, const5)
+    input_np = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32)
+    input_ms = Tensor(input_np, mstype.float32)
+    index_np_0 = np.random.randint(3, size=(3, 4, 5)).astype(np.int32)
+    index_np_1 = np.random.randint(4, size=(4, 5)).astype(np.int32)
+    index_0 = Tensor(index_np_0, mstype.int32)
+    index_1 = Tensor(index_np_1, mstype.int32)
+    out0, out1, out2, out3, out4, out5 = net(input_ms, index_0, index_1)
+    assert np.all(out0.asnumpy() == (input_np[index_np_0, index_np_1, 0:3] + const0))
+    assert np.all(out1.asnumpy() == (input_np[0:3, index_np_0, ...] + const1))
+    assert np.all(out2.asnumpy() == (input_np[0, index_np_0, index_np_1] + const2))
+    assert np.all(out3.asnumpy() == (input_np[..., index_np_0, 0:3] + const3))
+    assert np.all(out4.asnumpy() == (input_np[0:2, index_np_0, index_np_1] + const4))
+    assert np.all(out5.asnumpy() == (input_np[..., index_np_0, index_np_1] + const5))
class TensorSetItemByMixedTensors_0(Cell):
    def __init__(self, value):
        super(TensorSetItemByMixedTensors_0, self).__init__()
-        self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8, 9), np.float32))
-        self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)),
-                                      mstype.float32),
-                               name="x")
+        self.const = Tensor(np.ones((3, 4, 5), np.float32))
+        self.param = Parameter(Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)),
+                                      mstype.float32),
+                               name="x")
        self.value = value

    def construct(self, index_0, index_1, index_2):
-        self.param[0:2, index_0, index_1, 2, index_2, 0:3, ...] = self.value
+        self.param[0:2, index_0, index_1] = self.value
        ret = self.param + self.const
        return ret
def test_setitem_by_mixed_tensors_0():
value = 88.0
net = TensorSetItemByMixedTensors_0(value)
index_0 = np.random.randint(3, size=(3, 4, 5))
index_1 = np.random.randint(4, size=(4, 5))
index_2 = np.random.randint(3, size=(2, 1, 4, 5))
index_0_ms = Tensor(index_0, mstype.int32)
index_1_ms = Tensor(index_1, mstype.int32)
index_2_ms = Tensor(index_2, mstype.int32)
input_np = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32)
const = np.ones((3, 4, 5), np.float32)
out = net(index_0_ms, index_1_ms, index_2_ms)
input_np[0:2, index_0, index_1] = value
assert np.all(out.asnumpy() == (input_np + const))
class TensorSetItemByMixedTensors_1(Cell):
    def __init__(self, value):
        super(TensorSetItemByMixedTensors_1, self).__init__()
-        self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float32))
-        self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
-                               name="x")
+        self.const = Tensor(np.ones((3, 4, 5), np.float32))
+        self.param = Parameter(Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32),
+                               name="x")
        self.value = value

    def construct(self, index_0, index_1, index_2):
-        self.param[0:2, index_0, index_1, ..., index_2, 2] = self.value
+        self.param[0:2, index_0, ...] = self.value
        ret = self.param + self.const
        return ret
def test_setitem_by_mixed_tensors_1():
value = 88.0
net = TensorSetItemByMixedTensors_1(value)
index_0 = np.random.randint(3, size=(3, 4, 5))
index_1 = np.random.randint(4, size=(4, 5))
index_2 = np.random.randint(3, size=(2, 1, 4, 5))
index_0_ms = Tensor(index_0, mstype.int32)
index_1_ms = Tensor(index_1, mstype.int32)
index_2_ms = Tensor(index_2, mstype.int32)
input_np = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32)
const = np.ones((3, 4, 5), np.float32)
out = net(index_0_ms, index_1_ms, index_2_ms)
input_np[0:2, index_0, ...] = value
assert np.all(out.asnumpy() == (input_np + const))
class TensorSetItemByMixedTensors_2(Cell):
    def __init__(self, value):
        super(TensorSetItemByMixedTensors_2, self).__init__()
-        self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float16))
-        self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float16),
-                               name="x")
+        self.const = Tensor(np.ones((3, 4, 5), np.float16))
+        self.param = Parameter(Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float16),
+                               name="x")
        self.value = value

    def construct(self, index_0, index_1, index_2):
-        self.param[..., index_0, index_1, index_2, 3] = self.value
+        self.param[..., index_0, 1] = self.value
        ret = self.param + self.const
        return ret
def test_setitem_by_mixed_tensors_2():
value = 88.0
net = TensorSetItemByMixedTensors_2(value)
index_0 = np.random.randint(3, size=(3, 4, 5))
index_1 = np.random.randint(4, size=(4, 5))
index_2 = np.random.randint(3, size=(2, 1, 4, 5))
index_0_ms = Tensor(index_0, mstype.int32)
index_1_ms = Tensor(index_1, mstype.int32)
index_2_ms = Tensor(index_2, mstype.int32)
input_np = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32)
const = np.ones((3, 4, 5), np.float32)
out = net(index_0_ms, index_1_ms, index_2_ms)
input_np[..., index_0, 1] = value
assert np.all(out.asnumpy() == (input_np + const))
class TensorGetItemByMixedTensorsTypeError(Cell):
    def __init__(self):
        super(TensorGetItemByMixedTensorsTypeError, self).__init__()
@@ -277,13 +302,13 @@ class TensorGetItemByMixedTensorsTypeError(Cell):
        return ret


-class TensorGetItemByMixedTensorsNumberError(Cell):
-    def __init__(self):
-        super(TensorGetItemByMixedTensorsNumberError, self).__init__()
-
-    def construct(self, x, index_0, index_1):
-        ret = x[index_0, index_1, 0:3, ..., index_1, index_0]
-        return ret
+def test_getitem_by_mixedtensor_exception():
+    input_ms = Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32)
+    index_0 = Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32)
+    index_1 = Tensor(np.random.randint(4, size=(3, 4, 5)), mstype.int32)
+    net1 = TensorGetItemByMixedTensorsTypeError()
+    with pytest.raises(TypeError):
+        net1(input_ms, index_0, index_1)
class TensorSetItemByOneTensorWithNumber(Cell):
@@ -299,6 +324,18 @@ class TensorSetItemByOneTensorWithNumber(Cell):
        return ret
def test_setitem_one_tensor_with_number():
value = 0.0
net = TensorSetItemByOneTensorWithNumber(value)
index_np = np.random.randint(4, size=(5, 4))
index = Tensor(index_np, mstype.int32)
input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8))
const = np.ones((6, 7, 8)).astype(np.float32)
out = net(index)
input_data[index_np] = value
assert np.all(out.asnumpy() == (input_data + const))
class TensorSetItemByOneTensorWithTensor(Cell):
    def __init__(self):
        super(TensorSetItemByOneTensorWithTensor, self).__init__()
@@ -311,6 +348,19 @@ class TensorSetItemByOneTensorWithTensor(Cell):
        return ret
def test_setitem_by_one_tensor_with_tensor():
net = TensorSetItemByOneTensorWithTensor()
index_np = np.random.randint(4, size=(5, 4))
index = Tensor(index_np, mstype.int32)
input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8))
const = np.ones((6, 7, 8)).astype(np.float32)
value = np.zeros((4, 7, 8)).astype(np.float32)
value_ms = Tensor(value, mstype.float32)
out = net(index, value_ms)
input_data[index_np] = value
assert np.all(out.asnumpy() == (input_data + const))
class TensorSetItemByOneTensorWithTupleOfNumber(Cell):
    def __init__(self, value):
        super(TensorSetItemByOneTensorWithTupleOfNumber, self).__init__()
@@ -324,6 +374,18 @@ class TensorSetItemByOneTensorWithTupleOfNumber(Cell):
        return ret
def test_setitem_by_one_tensor_with_tuple_number():
value = (0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7)
net = TensorSetItemByOneTensorWithTupleOfNumber(value)
input_np = np.random.randint(5, size=(5, 4))
input_ms = Tensor(input_np, mstype.int32)
input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32)
const = np.ones((6, 7, 8)).astype(np.float32)
out = net(input_ms)
input_data[input_np] = value
assert np.all(out.asnumpy() == (input_data + const))
class TensorSetItemByOneTensorWithTupleOfTensor(Cell):
    def __init__(self):
        super(TensorSetItemByOneTensorWithTupleOfTensor, self).__init__()
@@ -336,6 +398,23 @@ class TensorSetItemByOneTensorWithTupleOfTensor(Cell):
        return ret
def test_setitem_by_one_tensor_with_tuple_tensors():
net = TensorSetItemByOneTensorWithTupleOfTensor()
input_np = np.random.randint(6, size=(5, 4)).astype(np.int32)
input_ms = Tensor(input_np, mstype.int32)
input_data = np.arange(6 * 3 * 8).reshape((6, 3, 8)).astype(np.float32)
value_0_np = np.zeros((8,), np.float32)
value_1_np = np.ones((8,), np.float32)
value_2_np = np.ones((8,), np.float32)*2
value_0 = Tensor(value_0_np)
value_1 = Tensor(value_1_np)
value_2 = Tensor(value_2_np)
const = np.ones((6, 3, 8)).astype(np.float32)
out = net(input_ms, value_0, value_1, value_2)
input_data[input_np] = (value_0_np, value_1_np, value_2_np)
assert np.all(out.asnumpy() == (input_data + const))
class TensorSetItemByTensorsWithNumber(Cell):
    def __init__(self, value):
        super(TensorSetItemByTensorsWithNumber, self).__init__()
@@ -349,6 +428,22 @@ class TensorSetItemByTensorsWithNumber(Cell):
        return ret
def test_setitem_by_tensors_with_number():
value = 0.0
net = TensorSetItemByTensorsWithNumber(value)
index_0 = np.random.randint(6, size=(3, 4, 5))
index_1 = np.random.randint(7, size=(4, 5))
index_2 = np.random.randint(8, size=(5, 3, 4, 5))
index_0_ms = Tensor(index_0, mstype.int32)
index_1_ms = Tensor(index_1, mstype.int32)
index_2_ms = Tensor(index_2, mstype.int32)
out = net(index_0_ms, index_1_ms, index_2_ms)
const = np.ones((6, 7, 8)).astype(np.float32)
input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32)
input_data[index_0, index_1, index_2] = value
assert np.all(out.asnumpy() == (input_data + const))
class TensorSetItemByTensorsWithTensor(Cell):
    def __init__(self):
        super(TensorSetItemByTensorsWithTensor, self).__init__()
@@ -361,6 +456,23 @@ class TensorSetItemByTensorsWithTensor(Cell):
        return ret
def test_setitem_by_tensors_with_tensor():
net = TensorSetItemByTensorsWithTensor()
index_0 = np.random.randint(6, size=(3, 4, 5))
index_1 = np.random.randint(7, size=(4, 5))
index_2 = np.random.randint(8, size=(5, 3, 4, 5))
value = np.zeros((4, 5)).astype(np.float32)
index_0_ms = Tensor(index_0, mstype.int32)
index_1_ms = Tensor(index_1, mstype.int32)
index_2_ms = Tensor(index_2, mstype.int32)
value_ms = Tensor(value, mstype.float32)
out = net(index_0_ms, index_1_ms, index_2_ms, value_ms)
const = np.ones((6, 7, 8)).astype(np.float32)
input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32)
input_data[index_0, index_1, index_2] = value
assert np.all(out.asnumpy() == (input_data + const))
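
# The same rule applies to a tensor value: the (4, 5) value is broadcast to the (5, 3, 4, 5)
# common shape of the indices before being written. Illustrative NumPy sketch (names are
# illustrative, not part of the patched file):
import numpy as np

i0 = np.random.randint(6, size=(3, 4, 5))
i1 = np.random.randint(7, size=(4, 5))
i2 = np.random.randint(8, size=(5, 3, 4, 5))
ref = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32)
ref[i0, i1, i2] = np.zeros((4, 5), np.float32)           # (4, 5) broadcasts to (5, 3, 4, 5)
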
class TensorSetItemByTensorsWithTensorNumberError(Cell):
    def __init__(self):
        super(TensorSetItemByTensorsWithTensorNumberError, self).__init__()
@ -373,6 +485,17 @@ class TensorSetItemByTensorsWithTensorNumberError(Cell):
        return ret
def test_setitem_by_tensors_with_tensor_error():
index_0 = Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32)
index_1 = Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)
index_2 = Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)
index_3 = Tensor(np.random.randint(8, size=(1, 3, 4, 5)), mstype.int32)
value = Tensor(np.zeros((2, 5)), mstype.float32)
net = TensorSetItemByTensorsWithTensorNumberError()
with pytest.raises(IndexError):
net(index_0, index_1, index_2, index_3, value)
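
# The expected IndexError mirrors NumPy, which also rejects more index arrays than the
# target has dimensions. A minimal sketch with illustrative names, not part of the patched file:
import numpy as np

ref = np.zeros((6, 7, 8), np.float32)
idx = np.zeros((2,), np.int64)
try:
    ref[idx, idx, idx, idx] = 0.0                        # four index arrays for a 3-D array
except IndexError:
    pass                                                 # "too many indices for array"
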
class TensorSetItemByTensorsWithTupleOfNumber(Cell):
    def __init__(self, value):
        super(TensorSetItemByTensorsWithTupleOfNumber, self).__init__()
@ -386,6 +509,22 @@ class TensorSetItemByTensorsWithTupleOfNumber(Cell):
        return ret
def test_setitem_by_tensors_with_tuple_of_number():
value = (0.0, 1.1, 2.2, 3.3, 4.4)
net = TensorSetItemByTensorsWithTupleOfNumber(value)
index_0 = np.random.randint(6, size=(3, 4, 5))
index_1 = np.random.randint(7, size=(4, 5))
index_2 = np.random.randint(8, size=(5, 3, 4, 5))
index_0_ms = Tensor(index_0, mstype.int32)
index_1_ms = Tensor(index_1, mstype.int32)
index_2_ms = Tensor(index_2, mstype.int32)
input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32)
input_data[index_0, index_1, index_2] = value
const = np.ones((6, 7, 8)).astype(np.float32)
out = net(index_0_ms, index_1_ms, index_2_ms)
assert np.all(out.asnumpy() == (input_data + const))
class TensorSetItemByTensorsWithTupleOfTensor(Cell):
    def __init__(self):
        super(TensorSetItemByTensorsWithTupleOfTensor, self).__init__()
@ -398,6 +537,27 @@ class TensorSetItemByTensorsWithTupleOfTensor(Cell):
        return ret
def test_setitem_by_tensors_with_tuple_of_tensor():
value_0 = np.zeros((4, 5))
value_1 = np.ones((4, 5))
value_2 = np.ones((4, 5)) * 2
value_0_ms = Tensor(value_0, mstype.float32)
value_1_ms = Tensor(value_1, mstype.float32)
value_2_ms = Tensor(value_2, mstype.float32)
net = TensorSetItemByTensorsWithTupleOfTensor()
index_0 = np.random.randint(6, size=(3, 4, 5))
index_1 = np.random.randint(7, size=(4, 5))
index_2 = np.random.randint(8, size=(5, 3, 4, 5))
index_0_ms = Tensor(index_0, mstype.int32)
index_1_ms = Tensor(index_1, mstype.int32)
index_2_ms = Tensor(index_2, mstype.int32)
input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32)
input_data[index_0, index_1, index_2] = (value_0, value_1, value_2)
const = np.ones((6, 7, 8)).astype(np.float32)
out = net(index_0_ms, index_1_ms, index_2_ms, value_0_ms, value_1_ms, value_2_ms)
assert np.all(out.asnumpy() == (input_data + const))
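
# For a tuple of tensors, the three (4, 5) values stack into a (3, 4, 5) array, which then
# broadcasts to the (5, 3, 4, 5) common index shape. Illustrative NumPy sketch, not part of
# the patched file:
import numpy as np

values = (np.zeros((4, 5)), np.ones((4, 5)), np.ones((4, 5)) * 2)
assert np.array(values).shape == (3, 4, 5)
i0 = np.random.randint(6, size=(3, 4, 5))
i1 = np.random.randint(7, size=(4, 5))
i2 = np.random.randint(8, size=(5, 3, 4, 5))
ref = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32)
ref[i0, i1, i2] = values                                 # (3, 4, 5) broadcasts to (5, 3, 4, 5)
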
class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell):
    def __init__(self):
        super(TensorSetItemByTensorsWithTupleOfTensorNumberError, self).__init__()
@ -410,17 +570,44 @@ class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell):
        return ret
-class TensorSetItemByMixedTensors(Cell):
-    def __init__(self):
-        super(TensorSetItemByMixedTensors, self).__init__()
-        self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
-        self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
-        self.value = 99.0
-    def construct(self, index_0, index_1):
-        self.param[index_0, index_1, 0:6] = self.value
-        ret = self.param + self.const
-        return ret

def test_setitem_by_tensor_with_tuple_of_tensor_error():
    net = TensorSetItemByTensorsWithTupleOfTensorNumberError()
    index_0_ms = Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32)
    index_1_ms = Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)
    index_2_ms = Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)
    value_0 = np.zeros((4, 5))
    value_1 = np.ones((4, 5))
    value_0_ms = Tensor(value_0, mstype.float32)
    value_1_ms = Tensor(value_1, mstype.float32)
    with pytest.raises(ValueError):
        net(index_0_ms, index_1_ms, index_2_ms, value_0_ms, value_1_ms)


def test_setitem_grad():
    class Net(Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.weight = Parameter(
                Tensor(np.ones([4, 4, 5]), dtype=mstype.float32), "b1", requires_grad=True)

        def construct(self, a, b):
            a[1:3:1, ::] = b
            c = a + self.weight
            return c

    class GradNet(Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, x, y, sens):
            return C.grad_by_list_with_sens(self.net, self.weights)(x, y, sens)

    net = GradNet(Net())
    x = Tensor(np.ones([4, 4, 5]).astype(np.float32), mstype.float32)
    y = Tensor(np.array([3]).astype(np.float32), mstype.float32)
    sens = Tensor(np.ones([4, 4, 5]).astype(np.float32), mstype.float32)
    net(x, y, sens)
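
# The slice assignment that the gradient flows through above is a plain NumPy-style fill of
# rows 1 and 2 by broadcasting. Illustrative sketch of the forward semantics only, not part
# of the patched file:
import numpy as np

a = np.ones((4, 4, 5), np.float32)
b = np.array([3], np.float32)
a[1:3:1, ::] = b                                         # rows 1 and 2 become 3.0 via broadcasting
assert a[1, 0, 0] == 3.0 and a[0, 0, 0] == 1.0
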
class TensorAssignWithSliceError1(Cell):
@ -475,7 +662,6 @@ class TensorAssignWithSlice(Cell):
def test_tensor_assign():
-    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    net = TensorAssignWithSlice()
    net2 = TensorAssignWithSlice2()
    net_e1 = TensorAssignWithSliceError1()
@ -621,7 +807,7 @@ class TensorAssignWithTupleInteger(Cell):
class TensorAssignWithBoolTensorIndex(Cell):
    def __init__(self):
        super(TensorAssignWithBoolTensorIndex, self).__init__()
-        self.t = Tensor(np.arange(60).reshape([3, 4, 5]), dtype=mstype.float32)
        self.t = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
        self.u_scalar = 5

    def construct(self, a, b, c, u_tensor):
@ -643,8 +829,7 @@ class TensorAssignWithBoolTensorIndexError(Cell):
class TensorAssignWithBoolTensorIndex2(Cell):
    def __init__(self):
        super(TensorAssignWithBoolTensorIndex2, self).__init__()
-        self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float32)
-        self.t = Tensor(np.arange(60).reshape([3, 4, 5]), dtype=mstype.float32)
        self.t = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
        self.u_scalar = 5

    def construct(self, a, u_tensor):
@ -666,7 +851,40 @@ class TensorAssignWithBoolTensorIndex2Error(Cell):
        return a

-def Xtest_tensor_assign_bool_index():
def test_tensor_assign_bool_index_0():
a = np.arange(60).reshape(3, 4, 5)
b = a > 5
c = a < 3
Ta = Tensor(a, dtype=mstype.float32)
Tb = Tensor(b)
Tc = Tensor(c)
u_tensor = Tensor([1], dtype=mstype.float32)
net1 = TensorAssignWithBoolTensorIndex()
out = net1(Ta, Tb, Tc, u_tensor)
res = np.arange(60).reshape(3, 4, 5)
res[c] = 5
res[b] = 1
res = res + np.ones([3, 4, 5])
assert np.all(out.asnumpy() == res)
def test_tensor_assign_bool_index_1():
a = np.arange(60).reshape(3, 4, 5)
Ta = Tensor(a, dtype=mstype.float32)
u_tensor = Tensor([1], dtype=mstype.float32)
net2 = TensorAssignWithBoolTensorIndex2()
out = net2(Ta, u_tensor)
res = np.arange(60).reshape(3, 4, 5)
res[res > 8] = 1
res[res >= 6] = 5
res[res < 3] = 5
res[res <= 5] = 1
res[res == 5] = 5
res = res + np.ones([3, 4, 5])
assert np.all(out.asnumpy() == res)
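
# In the chain above each boolean mask is recomputed on the already-updated array, so the
# order of the assignments matters when building the expected value. Tiny illustration,
# not part of the patched file:
import numpy as np

x = np.arange(6)          # [0 1 2 3 4 5]
x[x > 3] = 1              # [0 1 2 3 1 1]  -- mask taken from current values
x[x >= 2] = 0             # [0 1 0 0 1 1]  -- 4 and 5 were already rewritten to 1
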
def test_tensor_assign_bool_index_exception():
    a = np.arange(60).reshape(3, 4, 5)
    b = a > 5
    c = a < 3
@ -679,8 +897,6 @@ def Xtest_tensor_assign_bool_index():
    u_scalar = 5
    net1 = TensorAssignWithBoolTensorIndex()
    net2 = TensorAssignWithBoolTensorIndex2()
-    net1(Ta, Tb, Tc, u_tensor)
-    net1(Ta, Tb, Tc, u_tensor)
    with pytest.raises(ValueError):
        net1(Ta, Td, Tc, u_tensor)
    with pytest.raises(IndexError):
@ -695,14 +911,14 @@ def Xtest_tensor_assign_bool_index():
    with pytest.raises(ValueError):
        net2(Ta, u_tensor_error)
    net3 = TensorAssignWithBoolTensorIndexError()
-    with pytest.raises(AttributeError):
    with pytest.raises(IndexError):
        net3(Ta, Tb, Tc, u_tensor)
-    with pytest.raises(AttributeError):
    with pytest.raises(IndexError):
        net3(Ta, Tb, Tc, u_scalar)
    net4 = TensorAssignWithBoolTensorIndex2Error()
-    with pytest.raises(AttributeError):
    with pytest.raises(IndexError):
        net4(Ta, u_tensor)
-    with pytest.raises(AttributeError):
    with pytest.raises(IndexError):
        net4(Ta, u_scalar)