!49461 fix masked_fill caused neg loss on CUDA 10.1

Merge pull request !49461 from 吕昱峰(Nate.River)/masked_fill
This commit is contained in:
i-robot 2023-03-02 09:33:32 +00:00 committed by Gitee
commit 42e5324b50
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 7 additions and 3 deletions

View File

@ -2707,8 +2707,11 @@ def masked_fill(x, mask, value):
Fills elements of Tensor with value where mask is True.
"""
check_is_tensor(mask)
if not isinstance(value, Tensor):
value = Tensor(value)
check_type_name('mask', mask.dtype, [mstype.bool_], "Tensor")
return F.masked_fill(x, mask, value)
masked_value = P.FillV2()(x.shape, value.astype(x.dtype))
return P.Select()(mask, masked_value, x)
def col2im(*inputs):

View File

@ -4845,10 +4845,11 @@ def masked_fill(input_x, mask, value):
>>> print(output)
[0.5 0.5 3. 0.5]
"""
_fill = _get_cache_prim(P.FillV2)()
if isinstance(value, (float, int)) and isinstance(input_x, Tensor):
value = scalar_to_tensor_(value, input_x.dtype)
masked_fill_ = _get_cache_prim(P.MaskedFill)()
return masked_fill_(input_x, mask, value)
masked_value = _fill(input_x.shape, value)
return select(mask, masked_value, input_x)
def diag(input_x):