From 28d598f639defee57d92e3862bc5081bd65da6db Mon Sep 17 00:00:00 2001 From: lvyufeng Date: Sat, 6 Aug 2022 11:41:09 +0800 Subject: [PATCH] fix docs and add check for nn.CrossEntropyLoss --- .../nn/mindspore.nn.CrossEntropyLoss.rst | 7 ++- mindspore/python/mindspore/nn/loss/loss.py | 50 +++++++++++++++++-- .../python/mindspore/ops/function/nn_func.py | 6 +++ 3 files changed, 56 insertions(+), 7 deletions(-) diff --git a/docs/api/api_python/nn/mindspore.nn.CrossEntropyLoss.rst b/docs/api/api_python/nn/mindspore.nn.CrossEntropyLoss.rst index 95c1a793359..ea2f3ab840d 100644 --- a/docs/api/api_python/nn/mindspore.nn.CrossEntropyLoss.rst +++ b/docs/api/api_python/nn/mindspore.nn.CrossEntropyLoss.rst @@ -55,10 +55,9 @@ mindspore.nn.CrossEntropyLoss - **label_smoothing** (float): 标签平滑值,用于计算Loss时防止模型过拟合的正则化手段。取值范围为[0.0, 1.0]。 默认值: 0.0。 输入: - - **logits** (Tensor) - 输入预测值,shape为 :math:`(N, C)` 或 :math:`(N, C, H, W)` - (针对二维数据), 或 :math:`(N, C, d_1, d_2, ..., d_K)` (针对高维数据)。输入值需为对数概率。数据类型仅支持float32或float16。 - - **labels** (Tensor) - 输入目标值,shape为 :math:`(N)` 或 :math:`(N, d_1, d_2, ..., d_K)` - (针对高维数据)。 + - **logits** (Tensor) - 输入预测值,shape为 :math:`(C,)` 、 :math:`(N, C)` 或 :math:`(N, C, d_1, d_2, ..., d_K)` (针对高维数据)。输入值需为对数概率。数据类型仅支持float32或float16。 + - **labels** (Tensor) - 输入目标值。若目标值为类别索引,则shape为 :math:`()` 、 :math:`(N)` 或 :math:`(N, d_1, d_2, ..., d_K)` ,数据类型仅支持int32。 + 若目标值为类别概率,则shape为 :math:`(C,)` 、 :math:`(N, C)` 或 :math:`(N, C, d_1, d_2, ..., d_K)` ,数据类型仅支持float32或float16。 返回: Tensor,一个数据类型与logits相同的Tensor。 diff --git a/mindspore/python/mindspore/nn/loss/loss.py b/mindspore/python/mindspore/nn/loss/loss.py index 821341a0964..0aedb71f02f 100644 --- a/mindspore/python/mindspore/nn/loss/loss.py +++ b/mindspore/python/mindspore/nn/loss/loss.py @@ -1917,6 +1917,41 @@ class NLLLoss(LossBase): return F.nll_loss(logits, labels, self.weight, self.ignore_index, self.reduction) +@constexpr +def _check_cross_entropy_inputs(logits_shape, label_shape, \ + logits_rank, label_rank, \ 
+ logits_dtype, label_dtype, \ + prim_name=None): + """Internal function, used to check whether the shape of logits and labels meets the requirements.""" + validator.check_type_name('logits', logits_dtype, [mstype.float16, mstype.float32], prim_name) + + msg_prefix = f'For \'{prim_name}\', the' if prim_name else "The" + if logits_rank == label_rank: + validator.check_type_name('labels', label_dtype, [mstype.float16, mstype.float32], prim_name) + if logits_shape != label_shape: + raise ValueError(f"{msg_prefix} shape of 'logits' should be (N, C, d_0, d_1, ...), " + f"and the shape of 'labels' should be (N, C, d_0, d_1, ...), " + f"but get 'logits' shape: {logits_shape} and 'labels' shape: {label_shape}.") + elif label_rank == logits_rank - 1: + validator.check_type_name('labels', label_dtype, [mstype.int32], prim_name) + logits_shape_new = (logits_shape[0], *logits_shape[2:]) + if logits_shape_new != label_shape: + raise ValueError(f"{msg_prefix} shape of 'logits' should be (N, C, d_0, d_1, ...), " + f"and the shape of 'labels' should be (N, d_0, d_1, ...), " + f"but get 'logits' shape: {logits_shape} and 'labels' shape: {label_shape}.") + else: + raise ValueError(f"{msg_prefix} rank of 'logits' and 'labels' should be:\n" + f"1. 'logits.ndim == labels.ndim' for probabilities, \n" + f"2. 'logits.ndim - 1 == labels.ndim' for class indices, \n" + f"but get 'logits' rank: {logits_rank} and 'labels' rank: {label_rank}.") + + +@constexpr +def _cross_entropy_ignore_index_warning(prim_name): + """Internal function, used to warn when ignore_index > 0 for probabilities.""" + log.warning(f"For \'{prim_name}\', 'ignore_index' does not work when 'labels' is Probability.") + + class CrossEntropyLoss(LossBase): r""" The cross entropy loss between input and target. @@ -1979,9 +2014,12 @@ class CrossEntropyLoss(LossBase): from overfitting when calculating Loss. The value range is [0.0, 1.0]. Default value: 0.0.
Inputs: - - **logits** (Tensor) - Tensor of shape :math:`(N, C)` where `C = number of classes` or :math:`(N, C, H, W)` - in case of 2D Loss, or :math:`(N, C, d_1, d_2, ..., d_K)`. Data type must be float16 or float32. - - **labels** (Tensor) -:math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` for high-dimensional data. + - **logits** (Tensor) - Tensor of shape :math:`(C,)`, :math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)`, + where `C = number of classes`. Data type must be float16 or float32. + - **labels** (Tensor) - For class indices, tensor of shape :math:`()`, :math:`(N)` or + :math:`(N, d_1, d_2, ..., d_K)`, data type must be int32. + For probabilities, tensor of shape :math:`(C,)`, :math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)`, + data type must be float16 or float32. Returns: Tensor, the computed cross entropy loss value. @@ -2031,4 +2069,10 @@ class CrossEntropyLoss(LossBase): def construct(self, logits, labels): _check_is_tensor('logits', logits, self.cls_name) _check_is_tensor('labels', labels, self.cls_name) + _check_cross_entropy_inputs(logits.shape, labels.shape, \ + logits.ndim, labels.ndim, \ + logits.dtype, labels.dtype, \ + self.cls_name) + if logits.ndim == labels.ndim and self.ignore_index > 0: + _cross_entropy_ignore_index_warning(self.cls_name) return F.cross_entropy(logits, labels, self.weight, self.ignore_index, self.reduction, self.label_smoothing) diff --git a/mindspore/python/mindspore/ops/function/nn_func.py b/mindspore/python/mindspore/ops/function/nn_func.py index 58807dbaba6..e5ede347e48 100644 --- a/mindspore/python/mindspore/ops/function/nn_func.py +++ b/mindspore/python/mindspore/ops/function/nn_func.py @@ -1354,6 +1354,10 @@ def _cross_entropy(inputs, target, target_dim, weight=None, reduction='mean', la if weight is None: weight = _ones_like(inputs) + else: + broadcast_shape = [1 for _ in range(inputs.ndim)] + broadcast_shape[1] = weight.shape[0] + weight = weight.reshape(broadcast_shape) if reduction == 'mean': return -(inputs
* target * weight).sum() / (inputs.size / n_classes) @@ -1421,6 +1425,8 @@ def nll_loss(inputs, target, weight=None, ignore_index=-100, reduction='mean', l ret = _nll_loss(inputs, target, -1, weight, ignore_index, reduction, label_smoothing) elif ndim == 4: ret = _nll_loss(inputs, target, 1, weight, ignore_index, reduction, label_smoothing) + elif ndim == 1: + ret = _nll_loss(inputs, target, 0, weight, ignore_index, reduction, label_smoothing) else: n = inputs.shape[0] c = inputs.shape[1]