!42915 【1.9】修复切片索引为tensor类型的bug

Merge pull request !42915 from huoxinyou/0926tensorslice19
This commit is contained in:
i-robot 2022-09-28 01:44:36 +00:00 committed by Gitee
commit d8fe6a2489
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 2 additions and 3 deletions

View File

@ -785,8 +785,6 @@ def _generate_updates_from_tensor(data, index, value, op_type):
value = value.astype(data.dtype)
if is_shape_unknown(F.shape(data)):
data_shape = F.dyn_shape(data)
if F.rank(index) == 0:
index = F.expand_dims(index, -1)
index_shape = F.dyn_shape(index)
updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type, True)
updates = dynamic_broadcast_to(value, updates_shape)
@ -843,12 +841,13 @@ def tensor_setitem_by_ellipsis(self, index, value):
def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
"""Set a tensor item by an int tensor with a tensor."""
if F.rank(index) == 0:
index = F.expand_dims(index, -1)
updates = _generate_updates_from_tensor(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR)
index = F.select(index < 0, index + F.shape(data)[0], index)
index = F.expand_dims(index, -1)
if F.rank(index) < 2:
index = F.expand_dims(index, 0)
if F.rank(updates) == 0:
updates = F.expand_dims(updates, 0)
return F.tensor_scatter_update(data, index, updates)