forked from mindspore-Ecosystem/mindspore
setitem by bool support nan
This commit is contained in:
parent
24b3ad9ef0
commit
1df90af6d7
|
@ -836,10 +836,12 @@ def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
|
|||
|
||||
def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
|
||||
"""Set a tensor item by a bool tensor with a tensor."""
|
||||
dtype = F.dtype(data)
|
||||
u_cast = F.cast(value, dtype)
|
||||
index = index.reshape(const_utils.generate_padding_shape(index.shape, len(data.shape)))
|
||||
result = u_cast * index + data * F.logical_not(index)
|
||||
index = F.broadcast_to(index, data.shape)
|
||||
value = F.cast(value, F.dtype(data))
|
||||
value = value.reshape(const_utils.generate_padding_shape(value.shape, len(data.shape)))
|
||||
value = F.broadcast_to(value, data.shape)
|
||||
result = F.select(index, value, data)
|
||||
return result
|
||||
|
||||
|
||||
|
|
|
@ -38,3 +38,19 @@ def test_tensor_slice_by_bool_broadcast():
|
|||
data_np[index_np] = value
|
||||
data_tensor[index_tensor] = value
|
||||
assert np.allclose(data_tensor.asnumpy(), data_np)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_tensor_slice_by_bool_nan():
|
||||
"""
|
||||
Feature: Tensor-setitem-by-bool support nan.
|
||||
Description: Tensor-setitem-by-bool support nan.
|
||||
Expectation: success.
|
||||
"""
|
||||
data = Tensor(np.ones([2, 3, 4], np.float32))
|
||||
index = Tensor(np.array([False, False]))
|
||||
data[index] = Tensor([np.nan])
|
||||
assert np.allclose(data.asnumpy(), np.ones([2, 3, 4], np.float32))
|
||||
|
|
Loading…
Reference in New Issue