forked from mindspore-Ecosystem/mindspore
setitem debug
This commit is contained in:
parent
228a64de0f
commit
c22081abb3
|
@ -459,8 +459,7 @@ tensor_operator_registry.register("__setitem__", _tensor_setitem)
|
||||||
|
|
||||||
def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
|
def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
|
||||||
"""Set a tensor item by a int tensor with a tensor."""
|
"""Set a tensor item by a int tensor with a tensor."""
|
||||||
updates = _generate_updates_from_tensor(data, index, value,
|
updates = _generate_updates_from_tensor(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR)
|
||||||
const_utils.SET_ITEM_BY_ONE_TENSOR)
|
|
||||||
index = F.expand_dims(index, -1)
|
index = F.expand_dims(index, -1)
|
||||||
return P.TensorScatterUpdate()(data, index, updates)
|
return P.TensorScatterUpdate()(data, index, updates)
|
||||||
|
|
||||||
|
@ -504,9 +503,9 @@ def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value):
|
||||||
|
|
||||||
def _tensor_setitem_by_int_tensor_with_scalar(data, index, value):
|
def _tensor_setitem_by_int_tensor_with_scalar(data, index, value):
|
||||||
"""Set a tensor item by a int tensor with a scalar."""
|
"""Set a tensor item by a int tensor with a scalar."""
|
||||||
index = F.expand_dims(index, 0)
|
|
||||||
updates = _generate_updates_from_scalar(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR)
|
updates = _generate_updates_from_scalar(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR)
|
||||||
return P.ScatterUpdate()(data, index, updates)
|
index = F.expand_dims(index, -1)
|
||||||
|
return P.TensorScatterUpdate()(data, index, updates)
|
||||||
|
|
||||||
|
|
||||||
def tensor_setitem_by_tensor_with_number(data, index, value):
|
def tensor_setitem_by_tensor_with_number(data, index, value):
|
||||||
|
@ -547,11 +546,9 @@ def _tensor_indices_number(data, data_shape, index, indices, value):
|
||||||
|
|
||||||
def _tensor_setitem_by_tensor_with_tuple(data, index, value):
|
def _tensor_setitem_by_tensor_with_tuple(data, index, value):
|
||||||
"""Set a tensor item by a tensor with a tuple."""
|
"""Set a tensor item by a tensor with a tuple."""
|
||||||
updates = _generate_updates_from_tuple(data, index, value,
|
updates = _generate_updates_from_tuple(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR)
|
||||||
const_utils.SET_ITEM_BY_ONE_TENSOR)
|
|
||||||
index = F.expand_dims(index, -1)
|
index = F.expand_dims(index, -1)
|
||||||
result = P.TensorScatterUpdate()(data, index, updates)
|
return P.TensorScatterUpdate()(data, index, updates)
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def tensor_setitem_by_slice_with_number(data, input_slice, value):
|
def tensor_setitem_by_slice_with_number(data, input_slice, value):
|
||||||
|
|
Loading…
Reference in New Issue