From 6b888373138412f4f2a850b7e760bf20c7c7335c Mon Sep 17 00:00:00 2001 From: lvyufeng Date: Mon, 27 Feb 2023 15:48:58 +0800 Subject: [PATCH] fix masked_fill caused neg loss on CUDA 10.1 --- mindspore/python/mindspore/_extends/parse/standard_method.py | 5 ++++- mindspore/python/mindspore/ops/function/array_func.py | 5 +++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/mindspore/python/mindspore/_extends/parse/standard_method.py b/mindspore/python/mindspore/_extends/parse/standard_method.py index 4711e8aaa00..5cdca9d523d 100644 --- a/mindspore/python/mindspore/_extends/parse/standard_method.py +++ b/mindspore/python/mindspore/_extends/parse/standard_method.py @@ -2712,8 +2712,11 @@ def masked_fill(x, mask, value): Fills elements of Tensor with value where mask is True. """ check_is_tensor(mask) + if not isinstance(value, Tensor): + value = Tensor(value) check_type_name('mask', mask.dtype, [mstype.bool_], "Tensor") - return F.masked_fill(x, mask, value) + masked_value = P.FillV2()(x.shape, value.astype(x.dtype)) + return P.Select()(mask, masked_value, x) def col2im(*inputs): diff --git a/mindspore/python/mindspore/ops/function/array_func.py b/mindspore/python/mindspore/ops/function/array_func.py index 59e3e8a7857..c5ac8fc5934 100644 --- a/mindspore/python/mindspore/ops/function/array_func.py +++ b/mindspore/python/mindspore/ops/function/array_func.py @@ -4844,10 +4844,11 @@ def masked_fill(input_x, mask, value): >>> print(output) [0.5 0.5 3. 0.5] """ + _fill = _get_cache_prim(P.FillV2)() if isinstance(value, (float, int)) and isinstance(input_x, Tensor): value = scalar_to_tensor_(value, input_x.dtype) - masked_fill_ = _get_cache_prim(P.MaskedFill)() - return masked_fill_(input_x, mask, value) + masked_value = _fill(input_x.shape, value) + return select(mask, masked_value, input_x) def diag(input_x):