forked from mindspore-Ecosystem/mindspore
!49461 fix masked_fill caused neg loss on CUDA 10.1
Merge pull request !49461 from 吕昱峰(Nate.River)/masked_fill
This commit is contained in:
commit
42e5324b50
|
@ -2707,8 +2707,11 @@ def masked_fill(x, mask, value):
|
|||
Fills elements of Tensor with value where mask is True.
|
||||
"""
|
||||
check_is_tensor(mask)
|
||||
if not isinstance(value, Tensor):
|
||||
value = Tensor(value)
|
||||
check_type_name('mask', mask.dtype, [mstype.bool_], "Tensor")
|
||||
return F.masked_fill(x, mask, value)
|
||||
masked_value = P.FillV2()(x.shape, value.astype(x.dtype))
|
||||
return P.Select()(mask, masked_value, x)
|
||||
|
||||
|
||||
def col2im(*inputs):
|
||||
|
|
|
@ -4845,10 +4845,11 @@ def masked_fill(input_x, mask, value):
|
|||
>>> print(output)
|
||||
[0.5 0.5 3. 0.5]
|
||||
"""
|
||||
_fill = _get_cache_prim(P.FillV2)()
|
||||
if isinstance(value, (float, int)) and isinstance(input_x, Tensor):
|
||||
value = scalar_to_tensor_(value, input_x.dtype)
|
||||
masked_fill_ = _get_cache_prim(P.MaskedFill)()
|
||||
return masked_fill_(input_x, mask, value)
|
||||
masked_value = _fill(input_x.shape, value)
|
||||
return select(mask, masked_value, input_x)
|
||||
|
||||
|
||||
def diag(input_x):
|
||||
|
|
Loading…
Reference in New Issue