setitem debug

This commit is contained in:
yepei6 2021-02-07 19:50:36 +08:00 committed by Gitee
parent 228a64de0f
commit c22081abb3
1 changed files with 5 additions and 8 deletions

View File

@ -459,8 +459,7 @@ tensor_operator_registry.register("__setitem__", _tensor_setitem)
def _tensor_setitem_by_int_tensor_with_tensor(data, index, value): def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
"""Set a tensor item by a int tensor with a tensor.""" """Set a tensor item by a int tensor with a tensor."""
updates = _generate_updates_from_tensor(data, index, value, updates = _generate_updates_from_tensor(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR)
const_utils.SET_ITEM_BY_ONE_TENSOR)
index = F.expand_dims(index, -1) index = F.expand_dims(index, -1)
return P.TensorScatterUpdate()(data, index, updates) return P.TensorScatterUpdate()(data, index, updates)
@ -504,9 +503,9 @@ def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value):
def _tensor_setitem_by_int_tensor_with_scalar(data, index, value): def _tensor_setitem_by_int_tensor_with_scalar(data, index, value):
"""Set a tensor item by a int tensor with a scalar.""" """Set a tensor item by a int tensor with a scalar."""
index = F.expand_dims(index, 0)
updates = _generate_updates_from_scalar(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR) updates = _generate_updates_from_scalar(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR)
return P.ScatterUpdate()(data, index, updates) index = F.expand_dims(index, -1)
return P.TensorScatterUpdate()(data, index, updates)
def tensor_setitem_by_tensor_with_number(data, index, value): def tensor_setitem_by_tensor_with_number(data, index, value):
@ -547,11 +546,9 @@ def _tensor_indices_number(data, data_shape, index, indices, value):
def _tensor_setitem_by_tensor_with_tuple(data, index, value): def _tensor_setitem_by_tensor_with_tuple(data, index, value):
"""Set a tensor item by a tensor with a tuple.""" """Set a tensor item by a tensor with a tuple."""
updates = _generate_updates_from_tuple(data, index, value, updates = _generate_updates_from_tuple(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR)
const_utils.SET_ITEM_BY_ONE_TENSOR)
index = F.expand_dims(index, -1) index = F.expand_dims(index, -1)
result = P.TensorScatterUpdate()(data, index, updates) return P.TensorScatterUpdate()(data, index, updates)
return result
def tensor_setitem_by_slice_with_number(data, input_slice, value): def tensor_setitem_by_slice_with_number(data, input_slice, value):