forked from mindspore-Ecosystem/mindspore
setitem by bool support nan
This commit is contained in:
parent
24b3ad9ef0
commit
1df90af6d7
|
@ -836,10 +836,12 @@ def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
|
||||||
|
|
||||||
def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
|
def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
|
||||||
"""Set a tensor item by a bool tensor with a tensor."""
|
"""Set a tensor item by a bool tensor with a tensor."""
|
||||||
dtype = F.dtype(data)
|
|
||||||
u_cast = F.cast(value, dtype)
|
|
||||||
index = index.reshape(const_utils.generate_padding_shape(index.shape, len(data.shape)))
|
index = index.reshape(const_utils.generate_padding_shape(index.shape, len(data.shape)))
|
||||||
result = u_cast * index + data * F.logical_not(index)
|
index = F.broadcast_to(index, data.shape)
|
||||||
|
value = F.cast(value, F.dtype(data))
|
||||||
|
value = value.reshape(const_utils.generate_padding_shape(value.shape, len(data.shape)))
|
||||||
|
value = F.broadcast_to(value, data.shape)
|
||||||
|
result = F.select(index, value, data)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -38,3 +38,19 @@ def test_tensor_slice_by_bool_broadcast():
|
||||||
data_np[index_np] = value
|
data_np[index_np] = value
|
||||||
data_tensor[index_tensor] = value
|
data_tensor[index_tensor] = value
|
||||||
assert np.allclose(data_tensor.asnumpy(), data_np)
|
assert np.allclose(data_tensor.asnumpy(), data_np)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_tensor_slice_by_bool_nan():
|
||||||
|
"""
|
||||||
|
Feature: Tensor-setitem-by-bool support nan.
|
||||||
|
Description: Tensor-setitem-by-bool support nan.
|
||||||
|
Expectation: success.
|
||||||
|
"""
|
||||||
|
data = Tensor(np.ones([2, 3, 4], np.float32))
|
||||||
|
index = Tensor(np.array([False, False]))
|
||||||
|
data[index] = Tensor([np.nan])
|
||||||
|
assert np.allclose(data.asnumpy(), np.ones([2, 3, 4], np.float32))
|
||||||
|
|
Loading…
Reference in New Issue