From 553b8491a1282c82be3f4f9b15d96a7043246fd5 Mon Sep 17 00:00:00 2001 From: huoxinyou Date: Tue, 27 Sep 2022 11:39:41 +0800 Subject: [PATCH] fixed tensor slice bug --- .../mindspore/ops/composite/multitype_ops/_compile_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mindspore/python/mindspore/ops/composite/multitype_ops/_compile_utils.py b/mindspore/python/mindspore/ops/composite/multitype_ops/_compile_utils.py index 3898d4b49a3..5b235f8f939 100644 --- a/mindspore/python/mindspore/ops/composite/multitype_ops/_compile_utils.py +++ b/mindspore/python/mindspore/ops/composite/multitype_ops/_compile_utils.py @@ -785,8 +785,6 @@ def _generate_updates_from_tensor(data, index, value, op_type): value = value.astype(data.dtype) if is_shape_unknown(F.shape(data)): data_shape = F.dyn_shape(data) - if F.rank(index) == 0: - index = F.expand_dims(index, -1) index_shape = F.dyn_shape(index) updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type, True) updates = dynamic_broadcast_to(value, updates_shape) @@ -843,12 +841,13 @@ def tensor_setitem_by_ellipsis(self, index, value): def _tensor_setitem_by_int_tensor_with_tensor(data, index, value): """Set a tensor item by an int tensor with a tensor.""" + if F.rank(index) == 0: + index = F.expand_dims(index, -1) updates = _generate_updates_from_tensor(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR) index = F.select(index < 0, index + F.shape(data)[0], index) index = F.expand_dims(index, -1) if F.rank(index) < 2: index = F.expand_dims(index, 0) - if F.rank(updates) == 0: updates = F.expand_dims(updates, 0) return F.tensor_scatter_update(data, index, updates)