!37614 fix masked_fill with float value

Merge pull request !37614 from 范吉斌/fix_masked_fill
This commit is contained in:
i-robot 2022-07-11 02:54:50 +00:00 committed by Gitee
commit 682d76081d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 41 additions and 2 deletions

View File

@ -2699,6 +2699,8 @@ class Tensor(Tensor_):
[0. 1. 0. 0.]
"""
self._init_check()
if isinstance(value, (float, int)):
value = tensor_operator_registry.get("scalar_to_tensor")(value, self.dtype)
if not isinstance(mask, Tensor):
raise TypeError("For 'Tensor.masked_fill', the type of the argument 'mask' must be Tensor, but "
"got {}.".format(type(mask)))

View File

@ -753,7 +753,6 @@ class AdaptiveMaxPool3d(Cell):
`ouput_size` can be a tuple with 3 elements, or a single D for :math:`(D, D, D)`. :math:`D`,
:math:`H` and :math:`W` can be int or None which means the output size is the same as that of
the input.
return_indices (bool): If `return_indices` is True, the indices of max value would be output.
Default: False.

View File

@ -71,7 +71,6 @@ tensor_scatter_max_ = P.TensorScatterMax()
scalar_to_array_ = P.ScalarToArray()
scalar_to_tensor_ = P.ScalarToTensor()
tuple_to_array_ = P.TupleToArray()
masked_fill_ = P.MaskedFill()
masked_select_ = P.MaskedSelect()
matrix_band_part_ = P.array_ops.MatrixBandPart()
ger_ = P.Ger()
@ -3494,6 +3493,9 @@ def masked_fill(input_x, mask, value):
>>> print(output)
[0.5 0.5 3. 0.5]
"""
if isinstance(value, (float, int)) and isinstance(input_x, Tensor):
value = scalar_to_tensor_(value, input_x.dtype)
masked_fill_ = _get_cache_prim(P.MaskedFill)()
return masked_fill_(input_x, mask, value)

View File

@ -834,6 +834,7 @@ tensor_operator_registry.register('sum', P.ReduceSum)
tensor_operator_registry.register('split', P.Split)
tensor_operator_registry.register('select', P.Select)
tensor_operator_registry.register('zeros_like', P.ZerosLike)
tensor_operator_registry.register('scalar_to_tensor', scalar_to_tensor)
tensor_operator_registry.register('masked_fill', masked_fill)
tensor_operator_registry.register('masked_select', masked_select)
tensor_operator_registry.register('nonzero', nonzero)

View File

@ -20,6 +20,7 @@ import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops import functional as F
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
@ -131,6 +132,40 @@ def test_maskedfill_float_value():
maskedfill_value(0.5)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_func_masked_fill_float():
"""
Feature: Test func masked_fill.
Description: Test func masked_fill api with float value.
Expectation: The result match to expect.
"""
inputs = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]).astype(np.float16))
mask = Tensor(np.array([[True, True, False, True], [False, False, True, False]]).astype(np.bool))
value = 22
expect = np.array([[22, 22, 3, 22], [5, 6, 22, 8]]).astype(np.float16)
output = F.masked_fill(inputs, mask, value)
assert (output.asnumpy() == expect).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_tensor_masked_fill_float():
"""
Feature: Test Tensor masked_fill.
Description: Test Tensor masked_fill api with float value.
Expectation: The result match to expect.
"""
inputs = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]).astype(np.float16))
mask = Tensor(np.array([[True, True, False, True], [False, False, True, False]]).astype(np.bool))
value = 22
output = inputs.masked_fill(mask, value)
expect = np.array([[22, 22, 3, 22], [5, 6, 22, 8]]).astype(np.float16)
assert (output.asnumpy() == expect).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard