!12716 setitem debug to support 0d scalar tensor

From: @yepei6
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-03-01 15:37:17 +08:00 committed by Gitee
commit 5ea2dc6e69
1 changed files with 2 additions and 0 deletions

View File

@ -510,6 +510,8 @@ def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value):
def _tensor_setitem_by_int_tensor_with_scalar(data, index, value): def _tensor_setitem_by_int_tensor_with_scalar(data, index, value):
"""Set a tensor item by a int tensor with a scalar.""" """Set a tensor item by a int tensor with a scalar."""
if not F.shape(index):
index = F.expand_dims(index, 0)
updates = _generate_updates_from_scalar(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR) updates = _generate_updates_from_scalar(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR)
index = F.expand_dims(index, -1) index = F.expand_dims(index, -1)
return P.TensorScatterUpdate()(data, index, updates) return P.TensorScatterUpdate()(data, index, updates)