!37614 fix masked_fill with float value
Merge pull request !37614 from 范吉斌/fix_masked_fill
This commit is contained in:
commit
682d76081d
|
@ -2699,6 +2699,8 @@ class Tensor(Tensor_):
|
|||
[0. 1. 0. 0.]
|
||||
"""
|
||||
self._init_check()
|
||||
if isinstance(value, (float, int)):
|
||||
value = tensor_operator_registry.get("scalar_to_tensor")(value, self.dtype)
|
||||
if not isinstance(mask, Tensor):
|
||||
raise TypeError("For 'Tensor.masked_fill', the type of the argument 'mask' must be Tensor, but "
|
||||
"got {}.".format(type(mask)))
|
||||
|
|
|
@ -753,7 +753,6 @@ class AdaptiveMaxPool3d(Cell):
|
|||
`ouput_size` can be a tuple with 3 elements, or a single D for :math:`(D, D, D)`. :math:`D`,
|
||||
:math:`H` and :math:`W` can be int or None which means the output size is the same as that of
|
||||
the input.
|
||||
|
||||
return_indices (bool): If `return_indices` is True, the indices of max value would be output.
|
||||
Default: False.
|
||||
|
||||
|
|
|
@ -71,7 +71,6 @@ tensor_scatter_max_ = P.TensorScatterMax()
|
|||
scalar_to_array_ = P.ScalarToArray()
|
||||
scalar_to_tensor_ = P.ScalarToTensor()
|
||||
tuple_to_array_ = P.TupleToArray()
|
||||
masked_fill_ = P.MaskedFill()
|
||||
masked_select_ = P.MaskedSelect()
|
||||
matrix_band_part_ = P.array_ops.MatrixBandPart()
|
||||
ger_ = P.Ger()
|
||||
|
@ -3494,6 +3493,9 @@ def masked_fill(input_x, mask, value):
|
|||
>>> print(output)
|
||||
[0.5 0.5 3. 0.5]
|
||||
"""
|
||||
if isinstance(value, (float, int)) and isinstance(input_x, Tensor):
|
||||
value = scalar_to_tensor_(value, input_x.dtype)
|
||||
masked_fill_ = _get_cache_prim(P.MaskedFill)()
|
||||
return masked_fill_(input_x, mask, value)
|
||||
|
||||
|
||||
|
|
|
@ -834,6 +834,7 @@ tensor_operator_registry.register('sum', P.ReduceSum)
|
|||
tensor_operator_registry.register('split', P.Split)
|
||||
tensor_operator_registry.register('select', P.Select)
|
||||
tensor_operator_registry.register('zeros_like', P.ZerosLike)
|
||||
tensor_operator_registry.register('scalar_to_tensor', scalar_to_tensor)
|
||||
tensor_operator_registry.register('masked_fill', masked_fill)
|
||||
tensor_operator_registry.register('masked_select', masked_select)
|
||||
tensor_operator_registry.register('nonzero', nonzero)
|
||||
|
|
|
@ -20,6 +20,7 @@ import mindspore.context as context
|
|||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
|
||||
|
@ -131,6 +132,40 @@ def test_maskedfill_float_value():
|
|||
maskedfill_value(0.5)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_func_masked_fill_float():
|
||||
"""
|
||||
Feature: Test func masked_fill.
|
||||
Description: Test func masked_fill api with float value.
|
||||
Expectation: The result match to expect.
|
||||
"""
|
||||
inputs = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]).astype(np.float16))
|
||||
mask = Tensor(np.array([[True, True, False, True], [False, False, True, False]]).astype(np.bool))
|
||||
value = 22
|
||||
expect = np.array([[22, 22, 3, 22], [5, 6, 22, 8]]).astype(np.float16)
|
||||
output = F.masked_fill(inputs, mask, value)
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_tensor_masked_fill_float():
|
||||
"""
|
||||
Feature: Test Tensor masked_fill.
|
||||
Description: Test Tensor masked_fill api with float value.
|
||||
Expectation: The result match to expect.
|
||||
"""
|
||||
inputs = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]).astype(np.float16))
|
||||
mask = Tensor(np.array([[True, True, False, True], [False, False, True, False]]).astype(np.bool))
|
||||
value = 22
|
||||
output = inputs.masked_fill(mask, value)
|
||||
expect = np.array([[22, 22, 3, 22], [5, 6, 22, 8]]).astype(np.float16)
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
|
|
Loading…
Reference in New Issue