!41234 avoid masked_fill dtype support on ascend

Merge pull request !41234 from 吕昱峰(Nate.River)/master
This commit is contained in:
i-robot 2022-08-31 14:03:42 +00:00 committed by Gitee
commit 603e7027d3
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 8 additions and 1 deletions

View File

@ -1683,6 +1683,13 @@ def nll_loss(inputs, target, weight=None, ignore_index=-100, reduction='mean', l
return ret
def _masked_fill(inputs, mask, value):
_fill = _get_cache_prim(P.Fill)()
_select = _get_cache_prim(P.Select)()
masked_value = _fill(inputs.dtype, inputs.shape, value)
return _select(mask, masked_value, inputs)
def _nll_loss(inputs, target, target_dim=-1, weight=None, ignore_index=None, reduction='none', label_smoothing=0.0):
"""nll loss inner function"""
_neg = _get_cache_prim(P.Neg)()
@ -1695,7 +1702,7 @@ def _nll_loss(inputs, target, target_dim=-1, weight=None, ignore_index=None, red
target = target.expand_dims(target_dim)
if ignore_index is not None:
non_pad_mask = _equal(target, ignore_index)
target = target.masked_fill(non_pad_mask, 0)
target = _masked_fill(target, non_pad_mask, 0)
else:
non_pad_mask = target
loss = _neg(_gather_d(inputs, target_dim, target))