forked from mindspore-Ecosystem/mindspore
!42915 【1.9】修复切片索引为tensor类型的bug
Merge pull request !42915 from huoxinyou/0926tensorslice19
This commit is contained in:
commit
d8fe6a2489
|
@ -785,8 +785,6 @@ def _generate_updates_from_tensor(data, index, value, op_type):
|
|||
value = value.astype(data.dtype)
|
||||
if is_shape_unknown(F.shape(data)):
|
||||
data_shape = F.dyn_shape(data)
|
||||
if F.rank(index) == 0:
|
||||
index = F.expand_dims(index, -1)
|
||||
index_shape = F.dyn_shape(index)
|
||||
updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type, True)
|
||||
updates = dynamic_broadcast_to(value, updates_shape)
|
||||
|
@ -843,12 +841,13 @@ def tensor_setitem_by_ellipsis(self, index, value):
|
|||
|
||||
def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
|
||||
"""Set a tensor item by an int tensor with a tensor."""
|
||||
if F.rank(index) == 0:
|
||||
index = F.expand_dims(index, -1)
|
||||
updates = _generate_updates_from_tensor(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR)
|
||||
index = F.select(index < 0, index + F.shape(data)[0], index)
|
||||
index = F.expand_dims(index, -1)
|
||||
if F.rank(index) < 2:
|
||||
index = F.expand_dims(index, 0)
|
||||
if F.rank(updates) == 0:
|
||||
updates = F.expand_dims(updates, 0)
|
||||
return F.tensor_scatter_update(data, index, updates)
|
||||
|
||||
|
|
Loading…
Reference in New Issue