setitem by bool support nan

This commit is contained in:
wilfChen 2023-01-05 14:14:46 +08:00
parent 24b3ad9ef0
commit 1df90af6d7
2 changed files with 21 additions and 3 deletions

View File

@ -836,10 +836,12 @@ def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value): def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
"""Set a tensor item by a bool tensor with a tensor.""" """Set a tensor item by a bool tensor with a tensor."""
dtype = F.dtype(data)
u_cast = F.cast(value, dtype)
index = index.reshape(const_utils.generate_padding_shape(index.shape, len(data.shape))) index = index.reshape(const_utils.generate_padding_shape(index.shape, len(data.shape)))
result = u_cast * index + data * F.logical_not(index) index = F.broadcast_to(index, data.shape)
value = F.cast(value, F.dtype(data))
value = value.reshape(const_utils.generate_padding_shape(value.shape, len(data.shape)))
value = F.broadcast_to(value, data.shape)
result = F.select(index, value, data)
return result return result

View File

@ -38,3 +38,19 @@ def test_tensor_slice_by_bool_broadcast():
data_np[index_np] = value data_np[index_np] = value
data_tensor[index_tensor] = value data_tensor[index_tensor] = value
assert np.allclose(data_tensor.asnumpy(), data_np) assert np.allclose(data_tensor.asnumpy(), data_np)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_tensor_slice_by_bool_nan():
"""
Feature: Tensor-setitem-by-bool support nan.
Description: Tensor-setitem-by-bool support nan.
Expectation: success.
"""
data = Tensor(np.ones([2, 3, 4], np.float32))
index = Tensor(np.array([False, False]))
data[index] = Tensor([np.nan])
assert np.allclose(data.asnumpy(), np.ones([2, 3, 4], np.float32))