forked from mindspore-Ecosystem/mindspore
!41234 avoid masked_fill dtype support on ascend
Merge pull request !41234 from 吕昱峰(Nate.River)/master
This commit is contained in:
commit
603e7027d3
|
@ -1683,6 +1683,13 @@ def nll_loss(inputs, target, weight=None, ignore_index=-100, reduction='mean', l
|
|||
return ret
|
||||
|
||||
|
||||
def _masked_fill(inputs, mask, value):
|
||||
_fill = _get_cache_prim(P.Fill)()
|
||||
_select = _get_cache_prim(P.Select)()
|
||||
masked_value = _fill(inputs.dtype, inputs.shape, value)
|
||||
return _select(mask, masked_value, inputs)
|
||||
|
||||
|
||||
def _nll_loss(inputs, target, target_dim=-1, weight=None, ignore_index=None, reduction='none', label_smoothing=0.0):
|
||||
"""nll loss inner function"""
|
||||
_neg = _get_cache_prim(P.Neg)()
|
||||
|
@ -1695,7 +1702,7 @@ def _nll_loss(inputs, target, target_dim=-1, weight=None, ignore_index=None, red
|
|||
target = target.expand_dims(target_dim)
|
||||
if ignore_index is not None:
|
||||
non_pad_mask = _equal(target, ignore_index)
|
||||
target = target.masked_fill(non_pad_mask, 0)
|
||||
target = _masked_fill(target, non_pad_mask, 0)
|
||||
else:
|
||||
non_pad_mask = target
|
||||
loss = _neg(_gather_d(inputs, target_dim, target))
|
||||
|
|
Loading…
Reference in New Issue