diff --git a/mindspore/python/mindspore/_extends/parse/standard_method.py b/mindspore/python/mindspore/_extends/parse/standard_method.py index 0a62e19a7ea..d522ac31bcc 100644 --- a/mindspore/python/mindspore/_extends/parse/standard_method.py +++ b/mindspore/python/mindspore/_extends/parse/standard_method.py @@ -2707,8 +2707,11 @@ def masked_fill(x, mask, value): Fills elements of Tensor with value where mask is True. """ check_is_tensor(mask) + if not isinstance(value, Tensor): + value = Tensor(value) check_type_name('mask', mask.dtype, [mstype.bool_], "Tensor") - return F.masked_fill(x, mask, value) + masked_value = P.FillV2()(x.shape, value.astype(x.dtype)) + return P.Select()(mask, masked_value, x) def col2im(*inputs): diff --git a/mindspore/python/mindspore/ops/function/array_func.py b/mindspore/python/mindspore/ops/function/array_func.py index 01cb762e2eb..8330d7f6a6d 100644 --- a/mindspore/python/mindspore/ops/function/array_func.py +++ b/mindspore/python/mindspore/ops/function/array_func.py @@ -4845,10 +4845,11 @@ def masked_fill(input_x, mask, value): >>> print(output) [0.5 0.5 3. 0.5] """ + _fill = _get_cache_prim(P.FillV2)() if isinstance(value, (float, int)) and isinstance(input_x, Tensor): value = scalar_to_tensor_(value, input_x.dtype) - masked_fill_ = _get_cache_prim(P.MaskedFill)() - return masked_fill_(input_x, mask, value) + masked_value = _fill(input_x.shape, value) + return select(mask, masked_value, input_x) def diag(input_x):