!39904 fix docs and add check for nn.CrossEntropyLoss
Merge pull request !39904 from 吕昱峰(Nate.River)/master
This commit is contained in:
commit
ed1e4cc7b9
|
@ -55,10 +55,9 @@ mindspore.nn.CrossEntropyLoss
|
|||
- **label_smoothing** (float): 标签平滑值,用于计算Loss时防止模型过拟合的正则化手段。取值范围为[0.0, 1.0]。 默认值: 0.0。
|
||||
|
||||
输入:
|
||||
- **logits** (Tensor) - 输入预测值,shape为 :math:`(N, C)` 或 :math:`(N, C, H, W)`
|
||||
(针对二维数据), 或 :math:`(N, C, d_1, d_2, ..., d_K)` (针对高维数据)。输入值需为对数概率。数据类型仅支持float32或float16。
|
||||
- **labels** (Tensor) - 输入目标值,shape为 :math:`(N)` 或 :math:`(N, d_1, d_2, ..., d_K)`
|
||||
(针对高维数据)。
|
||||
- **logits** (Tensor) - 输入预测值,shape为 :math:`(C,)` 、 :math:`(N, C)` 或 :math:`(N, C, d_1, d_2, ..., d_K)` (针对高维数据)。输入值需为对数概率。数据类型仅支持float32或float16。
|
||||
- **labels** (Tensor) - 输入目标值。若目标值为类别索引,则shape为 :math:`()` 、 :math:`(N)` 或 :math:`(N, d_1, d_2, ..., d_K)` ,数据类型仅支持int32。
|
||||
若目标值为类别概率,则shape为 :math:`(C,)` 、 :math:`(N, C)` 或 :math:`(N, C, d_1, d_2, ..., d_K)` ,数据类型仅支持float32或float16。
|
||||
|
||||
返回:
|
||||
Tensor,一个数据类型与logits相同的Tensor。
|
||||
|
|
|
@ -1917,6 +1917,41 @@ class NLLLoss(LossBase):
|
|||
return F.nll_loss(logits, labels, self.weight, self.ignore_index, self.reduction)
|
||||
|
||||
|
||||
@constexpr
def _check_cross_entropy_inputs(logits_shape, label_shape,
                                logits_rank, label_rank,
                                logits_dtype, label_dtype,
                                prim_name=None):
    """
    Internal function, used to check whether the dtype, rank and shape of logits and labels meet the requirements.

    Two layouts are accepted:

    1. Probabilities: 'labels' has the same rank and the same shape as 'logits',
       and its dtype is float16 or float32.
    2. Class indices: 'labels' has exactly one dimension fewer than 'logits' (the class axis is dropped),
       and its dtype is int32. For batched input the shape of 'labels' must be the shape of 'logits'
       with axis 1 removed. For unbatched input ('logits' of rank 1, scalar 'labels') there is
       nothing left to compare once the class axis is dropped.

    Args:
        logits_shape (tuple): Shape of 'logits'.
        label_shape (tuple): Shape of 'labels'.
        logits_rank (int): Number of dimensions of 'logits'.
        label_rank (int): Number of dimensions of 'labels'.
        logits_dtype: Data type of 'logits', must be float16 or float32.
        label_dtype: Data type of 'labels'.
        prim_name (str): Name used as prefix of the error messages. Default: None.

    Raises:
        ValueError: If the ranks or the shapes of 'logits' and 'labels' do not match one of the layouts above.

    Note:
        Dtype checks are delegated to `validator.check_type_name`; the exception it raises on a
        mismatch is not visible here (presumably TypeError - confirm against the validator).
    """
    validator.check_type_name('logits', logits_dtype, [mstype.float16, mstype.float32], prim_name)

    msg_prefix = f'For \'{prim_name}\', the' if prim_name else "The"
    if logits_rank == label_rank:
        # Labels given as class probabilities: shapes must match exactly.
        validator.check_type_name('labels', label_dtype, [mstype.float16, mstype.float32], prim_name)
        if logits_shape != label_shape:
            raise ValueError(f"{msg_prefix} shape of 'logits' should be (N, C, d_0, d_1, ...), "
                             f"and the shape of 'labels' should be (N, C, d_0, d_1, ...), "
                             f"but get 'logits' shape: {logits_shape} and 'labels' shape: {label_shape}.")
    elif label_rank == logits_rank - 1:
        # Labels given as class indices.
        validator.check_type_name('labels', label_dtype, [mstype.int32], prim_name)
        # Unbatched input has logits of shape (C,) and scalar labels of shape (). Dropping axis 1 from
        # (C,) would yield (C,), which can never equal (), so the comparison only applies to batched input.
        if logits_rank != 1:
            logits_shape_new = (logits_shape[0], *logits_shape[2:])
            if logits_shape_new != label_shape:
                raise ValueError(f"{msg_prefix} shape of 'logits' should be (N, C, d_0, d_1, ...), "
                                 f"and the shape of 'labels' should be (N, d_0, d_1, ...), "
                                 f"but get 'logits' shape: {logits_shape} and 'labels' shape: {label_shape}.")
    else:
        raise ValueError(f"{msg_prefix} rank of 'logits' and 'labels' should be:\n"
                         "1. 'logits.ndim == labels.ndim' for probabilities, \n"
                         "2. 'logits.ndim - 1 == labels.ndim' for class indices, \n"
                         f"but get 'logits' rank: {logits_rank} and 'labels' rank: {label_rank}.")
|
||||
|
||||
|
||||
@constexpr
def _cross_entropy_ignore_index_warning(prim_name):
    """Internal function, used to log a warning that 'ignore_index' has no effect when labels are probabilities."""
    warning_msg = f"For \'{prim_name}\', 'ignore_index' does not work when 'labels' is Probability."
    log.warning(warning_msg)
|
||||
|
||||
|
||||
class CrossEntropyLoss(LossBase):
|
||||
r"""
|
||||
The cross entropy loss between input and target.
|
||||
|
@ -1979,9 +2014,12 @@ class CrossEntropyLoss(LossBase):
|
|||
from overfitting when calculating Loss. The value range is [0.0, 1.0]. Default value: 0.0.
|
||||
|
||||
Inputs:
|
||||
- **logits** (Tensor) - Tensor of shape :math:`(N, C)` where `C = number of classes` or :math:`(N, C, H, W)`
|
||||
in case of 2D Loss, or :math:`(N, C, d_1, d_2, ..., d_K)`. Data type must be float16 or float32.
|
||||
- **labels** (Tensor) -:math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` for high-dimensional data.
|
||||
- **logits** (Tensor) - Tensor of shape :math:`(C,)`, :math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)`,
|
||||
where `C = number of classes`. Data type must be float16 or float32.
|
||||
- **labels** (Tensor) - For class indices, tensor of shape :math:`()`, :math:`(N)` or
|
||||
:math:`(N, d_1, d_2, ..., d_K)` , data type must be int32.
|
||||
For probabilities, tensor of shape :math:`(C,)`, :math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)`,
|
||||
data type must be float16 or float32.
|
||||
|
||||
Returns:
|
||||
Tensor, the computed cross entropy loss value.
|
||||
|
@ -2031,4 +2069,10 @@ class CrossEntropyLoss(LossBase):
|
|||
def construct(self, logits, labels):
    """
    Validate 'logits' and 'labels', then compute the loss through `F.cross_entropy`.

    Both inputs are first checked to be tensors, then their shapes, ranks and dtypes are checked
    by `_check_cross_entropy_inputs`. A warning is emitted when 'labels' has the same rank as
    'logits' while `ignore_index` is positive.
    """
    cls_name = self.cls_name
    _check_is_tensor('logits', logits, cls_name)
    _check_is_tensor('labels', labels, cls_name)
    _check_cross_entropy_inputs(logits.shape, labels.shape,
                                logits.ndim, labels.ndim,
                                logits.dtype, labels.dtype,
                                cls_name)
    # Equal ranks mean the labels are given as probabilities rather than class indices.
    # NOTE(review): 'ignore_index == 0' does not trigger the warning - confirm this is intended.
    labels_are_probabilities = logits.ndim == labels.ndim
    if labels_are_probabilities and self.ignore_index > 0:
        _cross_entropy_ignore_index_warning(cls_name)
    loss = F.cross_entropy(logits, labels, self.weight, self.ignore_index, self.reduction, self.label_smoothing)
    return loss
|
||||
|
|
|
@ -1354,6 +1354,10 @@ def _cross_entropy(inputs, target, target_dim, weight=None, reduction='mean', la
|
|||
|
||||
if weight is None:
|
||||
weight = _ones_like(inputs)
|
||||
else:
|
||||
broadcast_shape = [1 for _ in range(inputs.ndim)]
|
||||
broadcast_shape[1] = weight.shape[0]
|
||||
weight = weight.reshape(broadcast_shape)
|
||||
|
||||
if reduction == 'mean':
|
||||
return -(inputs * target * weight).sum() / (inputs.size / n_classes)
|
||||
|
@ -1421,6 +1425,8 @@ def nll_loss(inputs, target, weight=None, ignore_index=-100, reduction='mean', l
|
|||
ret = _nll_loss(inputs, target, -1, weight, ignore_index, reduction, label_smoothing)
|
||||
elif ndim == 4:
|
||||
ret = _nll_loss(inputs, target, 1, weight, ignore_index, reduction, label_smoothing)
|
||||
elif ndim == 1:
|
||||
ret = _nll_loss(inputs, target, 0, weight, ignore_index, reduction, label_smoothing)
|
||||
else:
|
||||
n = inputs.shape[0]
|
||||
c = inputs.shape[1]
|
||||
|
|
Loading…
Reference in New Issue