From 1df90af6d702e5df5cf772c4fae9ce308f55727d Mon Sep 17 00:00:00 2001 From: wilfChen Date: Thu, 5 Jan 2023 14:14:46 +0800 Subject: [PATCH] setitem by bool support nan --- .../composite/multitype_ops/_compile_utils.py | 8 +++++--- tests/st/ops/ascend/test_tensor_setitem.py | 16 ++++++++++++++++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/mindspore/python/mindspore/ops/composite/multitype_ops/_compile_utils.py b/mindspore/python/mindspore/ops/composite/multitype_ops/_compile_utils.py index b7e26e736ac..a81a80573ae 100644 --- a/mindspore/python/mindspore/ops/composite/multitype_ops/_compile_utils.py +++ b/mindspore/python/mindspore/ops/composite/multitype_ops/_compile_utils.py @@ -836,10 +836,12 @@ def _tensor_setitem_by_int_tensor_with_tensor(data, index, value): def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value): """Set a tensor item by a bool tensor with a tensor.""" - dtype = F.dtype(data) - u_cast = F.cast(value, dtype) index = index.reshape(const_utils.generate_padding_shape(index.shape, len(data.shape))) - result = u_cast * index + data * F.logical_not(index) + index = F.broadcast_to(index, data.shape) + value = F.cast(value, F.dtype(data)) + value = value.reshape(const_utils.generate_padding_shape(value.shape, len(data.shape))) + value = F.broadcast_to(value, data.shape) + result = F.select(index, value, data) return result diff --git a/tests/st/ops/ascend/test_tensor_setitem.py b/tests/st/ops/ascend/test_tensor_setitem.py index 38520bdf513..738cf26a503 100644 --- a/tests/st/ops/ascend/test_tensor_setitem.py +++ b/tests/st/ops/ascend/test_tensor_setitem.py @@ -38,3 +38,19 @@ def test_tensor_slice_by_bool_broadcast(): data_np[index_np] = value data_tensor[index_tensor] = value assert np.allclose(data_tensor.asnumpy(), data_np) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_tensor_slice_by_bool_nan(): + """ + Feature: Tensor-setitem-by-bool support nan. + Description: Tensor-setitem-by-bool support nan. + Expectation: success. + """ + data = Tensor(np.ones([2, 3, 4], np.float32)) + index = Tensor(np.array([False, False])) + data[index] = Tensor([np.nan]) + assert np.allclose(data.asnumpy(), np.ones([2, 3, 4], np.float32))