diff --git a/mindspore/ops/composite/multitype_ops/_compile_utils.py b/mindspore/ops/composite/multitype_ops/_compile_utils.py index 5a90575e964..eb74fab8391 100644 --- a/mindspore/ops/composite/multitype_ops/_compile_utils.py +++ b/mindspore/ops/composite/multitype_ops/_compile_utils.py @@ -459,8 +459,7 @@ tensor_operator_registry.register("__setitem__", _tensor_setitem) def _tensor_setitem_by_int_tensor_with_tensor(data, index, value): """Set a tensor item by a int tensor with a tensor.""" - updates = _generate_updates_from_tensor(data, index, value, - const_utils.SET_ITEM_BY_ONE_TENSOR) + updates = _generate_updates_from_tensor(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR) index = F.expand_dims(index, -1) return P.TensorScatterUpdate()(data, index, updates) @@ -504,9 +503,9 @@ def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value): def _tensor_setitem_by_int_tensor_with_scalar(data, index, value): """Set a tensor item by a int tensor with a scalar.""" - index = F.expand_dims(index, 0) updates = _generate_updates_from_scalar(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR) - return P.ScatterUpdate()(data, index, updates) + index = F.expand_dims(index, -1) + return P.TensorScatterUpdate()(data, index, updates) def tensor_setitem_by_tensor_with_number(data, index, value): @@ -547,11 +546,9 @@ def _tensor_indices_number(data, data_shape, index, indices, value): def _tensor_setitem_by_tensor_with_tuple(data, index, value): """Set a tensor item by a tensor with a tuple.""" - updates = _generate_updates_from_tuple(data, index, value, - const_utils.SET_ITEM_BY_ONE_TENSOR) + updates = _generate_updates_from_tuple(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR) index = F.expand_dims(index, -1) - result = P.TensorScatterUpdate()(data, index, updates) - return result + return P.TensorScatterUpdate()(data, index, updates) def tensor_setitem_by_slice_with_number(data, input_slice, value):