From ff07db1f69798b15be78134ae5cec4453b010887 Mon Sep 17 00:00:00 2001 From: lvyufeng Date: Sat, 27 Aug 2022 08:45:42 +0800 Subject: [PATCH] fix nll_loss and cross_entropy --- mindspore/python/mindspore/ops/function/nn_func.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mindspore/python/mindspore/ops/function/nn_func.py b/mindspore/python/mindspore/ops/function/nn_func.py index 82a02d82bf2..2c1193dc3f9 100644 --- a/mindspore/python/mindspore/ops/function/nn_func.py +++ b/mindspore/python/mindspore/ops/function/nn_func.py @@ -1352,6 +1352,11 @@ def _nll_loss(inputs, target, target_dim=-1, weight=None, ignore_index=None, red if target.ndim == inputs.ndim - 1: target = target.expand_dims(target_dim) + if ignore_index is not None: + non_pad_mask = _equal(target, ignore_index) + target = target.masked_fill(non_pad_mask, 0) + else: + non_pad_mask = target loss = _neg(_gather_d(inputs, target_dim, target)) smooth_loss = _neg(inputs.sum(axis=target_dim, keepdims=True)) if weight is not None: @@ -1360,7 +1365,6 @@ def _nll_loss(inputs, target, target_dim=-1, weight=None, ignore_index=None, red else: loss_weights = _ones_like(loss) if ignore_index is not None: - non_pad_mask = _equal(target, ignore_index) loss = loss.masked_fill(non_pad_mask, 0.) loss_weights = loss_weights.masked_fill(non_pad_mask, 0.) smooth_loss = smooth_loss.masked_fill(non_pad_mask, 0.)