!41579 restore masked_fill for nll_loss

Merge pull request !41579 from 吕昱峰(Nate.River)/r1.8
This commit is contained in:
i-robot 2022-09-07 03:34:54 +00:00 committed by Gitee
commit 9a6699f060
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 6 additions and 10 deletions

View File

@@ -1598,8 +1598,11 @@ class HuberLoss(LossBase):
@constexpr
def _check_nll_loss_shape(logits_shape, label_shape, prim_name=None):
def _check_nll_loss_inputs(logits_shape, label_shape, logits_dtype, label_dtype, prim_name=None):
"""Internal function, used to check whether the shape of logits and labels meets the requirements."""
validator.check_type_name('logits', logits_dtype, [mstype.float16, mstype.float32], prim_name)
validator.check_type_name('labels', label_dtype, [mstype.int32], prim_name)
logits_shape_new = (logits_shape[0], *logits_shape[2:])
msg_prefix = f'For \'{prim_name}\', the' if prim_name else "The"
if logits_shape_new != label_shape:
@@ -1684,7 +1687,7 @@ class NLLLoss(LossBase):
def construct(self, logits, labels):
    """Validate inputs and compute the negative log likelihood loss.

    Checks that both arguments are Tensors, validates their shapes (and,
    after this commit, dtypes) against each other, then delegates the actual
    loss computation to ``F.nll_loss``.

    NOTE(review): this span is a diff rendering — it shows both the removed
    ``_check_nll_loss_shape`` call and its replacement
    ``_check_nll_loss_inputs``; in the merged source only the latter line
    remains. Verify against the repository before treating this as runnable.
    """
    # Both inputs must be Tensors before any shape/dtype inspection.
    _check_is_tensor('logits', logits, self.cls_name)
    _check_is_tensor('labels', labels, self.cls_name)
    _check_nll_loss_shape(logits.shape, labels.shape, self.cls_name)
    _check_nll_loss_inputs(logits.shape, labels.shape, logits.dtype, labels.dtype, self.cls_name)
    return F.nll_loss(logits, labels, self.weight, self.ignore_index, self.reduction)

View File

@@ -1342,13 +1342,6 @@ def nll_loss(inputs, target, weight=None, ignore_index=-100, reduction='mean', l
return ret
def _masked_fill(inputs, mask, value):
    """Return a tensor equal to `inputs` except that every position where
    `mask` is True is replaced by `value` (elementwise select against a
    constant-filled tensor of the same dtype/shape).
    """
    # Build a tensor of `value` matching `inputs`, then pick per element:
    # mask True -> filled value, mask False -> original input.
    fill_with_value = _get_cache_prim(P.Fill)()(inputs.dtype, inputs.shape, value)
    return _get_cache_prim(P.Select)()(mask, fill_with_value, inputs)
def _nll_loss(inputs, target, target_dim=-1, weight=None, ignore_index=None, reduction='none', label_smoothing=0.0):
"""nll loss inner function"""
_neg = _get_cache_prim(P.Neg)()
@@ -1361,7 +1354,7 @@ def _nll_loss(inputs, target, target_dim=-1, weight=None, ignore_index=None, red
target = target.expand_dims(target_dim)
if ignore_index is not None:
non_pad_mask = _equal(target, ignore_index)
target = _masked_fill(target, non_pad_mask, 0)
target = target.masked_fill(non_pad_mask, 0)
else:
non_pad_mask = target
loss = _neg(_gather_d(inputs, target_dim, target))