forked from mindspore-Ecosystem/mindspore
!41579 restore masked_fill for nll_loss
Merge pull request !41579 from 吕昱峰(Nate.River)/r1.8
This commit is contained in:
commit
9a6699f060
|
@ -1598,8 +1598,11 @@ class HuberLoss(LossBase):
|
|||
|
||||
|
||||
@constexpr
|
||||
def _check_nll_loss_shape(logits_shape, label_shape, prim_name=None):
|
||||
def _check_nll_loss_inputs(logits_shape, label_shape, logits_dtype, label_dtype, prim_name=None):
|
||||
"""Internal function, used to check whether the shape of logits and labels meets the requirements."""
|
||||
validator.check_type_name('logits', logits_dtype, [mstype.float16, mstype.float32], prim_name)
|
||||
validator.check_type_name('labels', label_dtype, [mstype.int32], prim_name)
|
||||
|
||||
logits_shape_new = (logits_shape[0], *logits_shape[2:])
|
||||
msg_prefix = f'For \'{prim_name}\', the' if prim_name else "The"
|
||||
if logits_shape_new != label_shape:
|
||||
|
@ -1684,7 +1687,7 @@ class NLLLoss(LossBase):
|
|||
def construct(self, logits, labels):
|
||||
_check_is_tensor('logits', logits, self.cls_name)
|
||||
_check_is_tensor('labels', labels, self.cls_name)
|
||||
_check_nll_loss_shape(logits.shape, labels.shape, self.cls_name)
|
||||
_check_nll_loss_inputs(logits.shape, labels.shape, logits.dtype, labels.dtype, self.cls_name)
|
||||
return F.nll_loss(logits, labels, self.weight, self.ignore_index, self.reduction)
|
||||
|
||||
|
||||
|
|
|
@ -1342,13 +1342,6 @@ def nll_loss(inputs, target, weight=None, ignore_index=-100, reduction='mean', l
|
|||
return ret
|
||||
|
||||
|
||||
def _masked_fill(inputs, mask, value):
|
||||
_fill = _get_cache_prim(P.Fill)()
|
||||
_select = _get_cache_prim(P.Select)()
|
||||
masked_value = _fill(inputs.dtype, inputs.shape, value)
|
||||
return _select(mask, masked_value, inputs)
|
||||
|
||||
|
||||
def _nll_loss(inputs, target, target_dim=-1, weight=None, ignore_index=None, reduction='none', label_smoothing=0.0):
|
||||
"""nll loss inner function"""
|
||||
_neg = _get_cache_prim(P.Neg)()
|
||||
|
@ -1361,7 +1354,7 @@ def _nll_loss(inputs, target, target_dim=-1, weight=None, ignore_index=None, red
|
|||
target = target.expand_dims(target_dim)
|
||||
if ignore_index is not None:
|
||||
non_pad_mask = _equal(target, ignore_index)
|
||||
target = _masked_fill(target, non_pad_mask, 0)
|
||||
target = target.masked_fill(non_pad_mask, 0)
|
||||
else:
|
||||
non_pad_mask = target
|
||||
loss = _neg(_gather_d(inputs, target_dim, target))
|
||||
|
|
Loading…
Reference in New Issue