forked from mindspore-Ecosystem/mindspore
fix nll_loss and cross_entropy
This commit is contained in:
parent
2df40c2deb
commit
ff07db1f69
|
@ -1352,6 +1352,11 @@ def _nll_loss(inputs, target, target_dim=-1, weight=None, ignore_index=None, red
|
|||
|
||||
if target.ndim == inputs.ndim - 1:
|
||||
target = target.expand_dims(target_dim)
|
||||
if ignore_index is not None:
|
||||
non_pad_mask = _equal(target, ignore_index)
|
||||
target = target.masked_fill(non_pad_mask, 0)
|
||||
else:
|
||||
non_pad_mask = target
|
||||
loss = _neg(_gather_d(inputs, target_dim, target))
|
||||
smooth_loss = _neg(inputs.sum(axis=target_dim, keepdims=True))
|
||||
if weight is not None:
|
||||
|
@ -1360,7 +1365,6 @@ def _nll_loss(inputs, target, target_dim=-1, weight=None, ignore_index=None, red
|
|||
else:
|
||||
loss_weights = _ones_like(loss)
|
||||
if ignore_index is not None:
|
||||
non_pad_mask = _equal(target, ignore_index)
|
||||
loss = loss.masked_fill(non_pad_mask, 0.)
|
||||
loss_weights = loss_weights.masked_fill(non_pad_mask, 0.)
|
||||
smooth_loss = smooth_loss.masked_fill(non_pad_mask, 0.)
|
||||
|
|
Loading…
Reference in New Issue