!47487 setitem-by-bool support nan

Merge pull request !47487 from chenweifeng/setitem-by-bool-support-nan
This commit is contained in:
i-robot 2023-01-06 02:04:45 +00:00 committed by Gitee
commit 1243ad2081
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 21 additions and 3 deletions

View File

@ -836,10 +836,12 @@ def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
"""Set a tensor item by a bool tensor with a tensor."""
dtype = F.dtype(data)
u_cast = F.cast(value, dtype)
index = index.reshape(const_utils.generate_padding_shape(index.shape, len(data.shape)))
result = u_cast * index + data * F.logical_not(index)
index = F.broadcast_to(index, data.shape)
value = F.cast(value, F.dtype(data))
value = value.reshape(const_utils.generate_padding_shape(value.shape, len(data.shape)))
value = F.broadcast_to(value, data.shape)
result = F.select(index, value, data)
return result

View File

@ -38,3 +38,19 @@ def test_tensor_slice_by_bool_broadcast():
data_np[index_np] = value
data_tensor[index_tensor] = value
assert np.allclose(data_tensor.asnumpy(), data_np)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_tensor_slice_by_bool_nan():
"""
Feature: Tensor-setitem-by-bool support nan.
Description: Tensor-setitem-by-bool support nan.
Expectation: success.
"""
data = Tensor(np.ones([2, 3, 4], np.float32))
index = Tensor(np.array([False, False]))
data[index] = Tensor([np.nan])
assert np.allclose(data.asnumpy(), np.ones([2, 3, 4], np.float32))