!39904 fix docs and add check for nn.CrossEntropyLoss

Merge pull request !39904 from 吕昱峰(Nate.River)/master
i-robot 2022-08-08 02:26:40 +00:00, committed by Gitee
commit ed1e4cc7b9
3 changed files with 56 additions and 7 deletions

@@ -55,10 +55,9 @@ mindspore.nn.CrossEntropyLoss
    - **label_smoothing** (float) - Label smoothing value, a regularization applied when computing the loss to prevent the model from overfitting. The value range is [0.0, 1.0]. Default value: 0.0.

    Inputs:
        - **logits** (Tensor) - Input predicted values, of shape :math:`(N, C)` or :math:`(N, C, H, W)`
          (for 2-D data), or :math:`(N, C, d_1, d_2, ..., d_K)` (for high-dimensional data). The input is expected to contain log-probabilities. Only float32 and float16 data types are supported.
        - **labels** (Tensor) - Input target values, of shape :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)`
          (for high-dimensional data).
        - **logits** (Tensor) - Input predicted values, of shape :math:`(C,)`, :math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)` (for high-dimensional data). The input is expected to contain log-probabilities. Only float32 and float16 data types are supported.
        - **labels** (Tensor) - Input target values. For class indices, the shape is :math:`()`, :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)`, and only int32 is supported.
          For class probabilities, the shape is :math:`(C,)`, :math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)`, and only float32 and float16 are supported.

    Returns:
        Tensor, a Tensor with the same data type as `logits`.
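
A minimal usage sketch of the two label modes documented above; the shapes and values here are illustrative only and are not part of the patch:

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

loss_fn = nn.CrossEntropyLoss()

# Class indices: float logits of shape (N, C), int32 labels of shape (N,).
logits = Tensor(np.random.randn(3, 5), ms.float32)
labels = Tensor(np.array([1, 0, 4]), ms.int32)
print(loss_fn(logits, labels))

# Class probabilities: float labels with the same shape (N, C) as the logits.
soft_labels = Tensor(np.random.rand(3, 5), ms.float32)
print(loss_fn(logits, soft_labels))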

@@ -1917,6 +1917,41 @@ class NLLLoss(LossBase):
        return F.nll_loss(logits, labels, self.weight, self.ignore_index, self.reduction)

@constexpr
def _check_cross_entropy_inputs(logits_shape, label_shape,
                                logits_rank, label_rank,
                                logits_dtype, label_dtype,
                                prim_name=None):
    """Internal function, used to check whether the shapes and dtypes of logits and labels meet the requirements."""
    validator.check_type_name('logits', logits_dtype, [mstype.float16, mstype.float32], prim_name)
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if logits_rank == label_rank:
        validator.check_type_name('labels', label_dtype, [mstype.float16, mstype.float32], prim_name)
        if logits_shape != label_shape:
            raise ValueError(f"{msg_prefix} shape of 'logits' should be (N, C, d_0, d_1, ...), "
                             f"and the shape of 'labels' should be (N, C, d_0, d_1, ...), "
                             f"but got 'logits' shape: {logits_shape} and 'labels' shape: {label_shape}.")
    elif label_rank == logits_rank - 1:
        validator.check_type_name('labels', label_dtype, [mstype.int32], prim_name)
        logits_shape_new = (logits_shape[0], *logits_shape[2:])
        if logits_shape_new != label_shape:
            raise ValueError(f"{msg_prefix} shape of 'logits' should be (N, C, d_0, d_1, ...), "
                             f"and the shape of 'labels' should be (N, d_0, d_1, ...), "
                             f"but got 'logits' shape: {logits_shape} and 'labels' shape: {label_shape}.")
    else:
        raise ValueError(f"{msg_prefix} rank of 'logits' and 'labels' should satisfy:\n"
                         f"1. 'logits.ndim == labels.ndim' for probabilities, or\n"
                         f"2. 'logits.ndim - 1 == labels.ndim' for class indices,\n"
                         f"but got 'logits' rank: {logits_rank} and 'labels' rank: {label_rank}.")


@constexpr
def _cross_entropy_ignore_index_warning(prim_name):
    """Internal function, used to warn when ignore_index > 0 and labels are probabilities."""
    log.warning(f"For '{prim_name}', 'ignore_index' does not work when 'labels' contains probabilities.")
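
Restated as a plain-Python sketch of the two accepted cases above (dtype checks omitted); shapes_compatible is a hypothetical helper name used only for illustration:

def shapes_compatible(logits_shape, label_shape):
    """Mirror the rank/shape rules of _check_cross_entropy_inputs, dtypes omitted."""
    if len(label_shape) == len(logits_shape):        # probabilities: shapes must match exactly
        return logits_shape == label_shape
    if len(label_shape) == len(logits_shape) - 1:    # class indices: labels drop the C axis
        return (logits_shape[0], *logits_shape[2:]) == label_shape
    return False                                     # any other rank pair is rejected

assert shapes_compatible((3, 5), (3, 5))      # probabilities
assert shapes_compatible((3, 5), (3,))        # class indices
assert not shapes_compatible((3, 5), (5,))    # mismatched index shape
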
class CrossEntropyLoss(LossBase):
    r"""
    The cross entropy loss between input and target.
@@ -1979,9 +2014,12 @@ class CrossEntropyLoss(LossBase):
        from overfitting when calculating the loss. The value range is [0.0, 1.0]. Default value: 0.0.

    Inputs:
        - **logits** (Tensor) - Tensor of shape :math:`(N, C)` where `C = number of classes`, or :math:`(N, C, H, W)`
          in case of 2D loss, or :math:`(N, C, d_1, d_2, ..., d_K)`. Data type must be float16 or float32.
        - **labels** (Tensor) - :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` for high-dimensional data.
        - **logits** (Tensor) - Tensor of shape :math:`(C,)`, :math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)`,
          where `C = number of classes`. Data type must be float16 or float32.
        - **labels** (Tensor) - For class indices, tensor of shape :math:`()`, :math:`(N)` or
          :math:`(N, d_1, d_2, ..., d_K)`, data type must be int32.
          For probabilities, tensor of shape :math:`(C,)`, :math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)`,
          data type must be float16 or float32.

    Returns:
        Tensor, the computed cross entropy loss value.
@@ -2031,4 +2069,10 @@ class CrossEntropyLoss(LossBase):
    def construct(self, logits, labels):
        _check_is_tensor('logits', logits, self.cls_name)
        _check_is_tensor('labels', labels, self.cls_name)
        _check_cross_entropy_inputs(logits.shape, labels.shape,
                                    logits.ndim, labels.ndim,
                                    logits.dtype, labels.dtype,
                                    self.cls_name)
        if logits.ndim == labels.ndim and self.ignore_index > 0:
            _cross_entropy_ignore_index_warning(self.cls_name)
        return F.cross_entropy(logits, labels, self.weight, self.ignore_index, self.reduction, self.label_smoothing)
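
A sketch of how the new check surfaces to callers (PyNative mode assumed; the mismatched shapes below are deliberate, and the example is not part of the patch):

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

loss_fn = nn.CrossEntropyLoss()
logits = Tensor(np.random.randn(3, 5), ms.float32)
bad_labels = Tensor(np.array([1, 0, 4, 2]), ms.int32)  # shape (4,), but (3,) is expected
try:
    loss_fn(logits, bad_labels)
except ValueError as err:
    print(err)  # "... but got 'logits' shape: (3, 5) and 'labels' shape: (4,)."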

@@ -1354,6 +1354,10 @@ def _cross_entropy(inputs, target, target_dim, weight=None, reduction='mean', label_smoothing=0.0):
    if weight is None:
        weight = _ones_like(inputs)
    else:
        broadcast_shape = [1 for _ in range(inputs.ndim)]
        broadcast_shape[1] = weight.shape[0]
        weight = weight.reshape(broadcast_shape)

    if reduction == 'mean':
        return -(inputs * target * weight).sum() / (inputs.size / n_classes)
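
A NumPy sketch of the weight broadcast added above, assuming only that weight is a per-class vector of shape (C,); the reshape lines it up with the class axis of inputs:

import numpy as np

inputs = np.random.randn(2, 5, 4, 4)   # (N, C, d_1, d_2) log-probabilities
weight = np.random.rand(5)             # per-class weights, shape (C,)

broadcast_shape = [1 for _ in range(inputs.ndim)]
broadcast_shape[1] = weight.shape[0]   # -> [1, 5, 1, 1]
weighted = inputs * weight.reshape(broadcast_shape)
print(weighted.shape)                  # (2, 5, 4, 4): weights applied along the C axis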
@@ -1421,6 +1425,8 @@ def nll_loss(inputs, target, weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0):
        ret = _nll_loss(inputs, target, -1, weight, ignore_index, reduction, label_smoothing)
    elif ndim == 4:
        ret = _nll_loss(inputs, target, 1, weight, ignore_index, reduction, label_smoothing)
    elif ndim == 1:
        ret = _nll_loss(inputs, target, 0, weight, ignore_index, reduction, label_smoothing)
    else:
        n = inputs.shape[0]
        c = inputs.shape[1]
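
The new ndim == 1 branch covers unbatched input: a single (C,) log-probability vector with a scalar class index. A sketch through the public nn.NLLLoss wrapper, assuming the scalar-target form is reachable end to end in this version:

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

loss_fn = nn.NLLLoss()
log_probs = Tensor(np.log([0.7, 0.2, 0.1]), ms.float32)  # shape (3,) = (C,)
target = Tensor(2, ms.int32)                             # scalar class index
print(loss_fn(log_probs, target))                        # -log(0.1)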