add nll_loss and cross_entropy for ops.functional and nn

lvyufeng 2022-05-30 10:27:57 +08:00
parent ccd3bfb443
commit ee30ffc3bf
13 changed files with 732 additions and 2 deletions


@ -207,12 +207,14 @@ Dropout Layer
mindspore.nn.BCELoss
mindspore.nn.BCEWithLogitsLoss
mindspore.nn.CosineEmbeddingLoss
mindspore.nn.CrossEntropyLoss
mindspore.nn.DiceLoss
mindspore.nn.FocalLoss
mindspore.nn.HuberLoss
mindspore.nn.L1Loss
mindspore.nn.MSELoss
mindspore.nn.MultiClassDiceLoss
mindspore.nn.NLLLoss
mindspore.nn.RMSELoss
mindspore.nn.SampledSoftmaxLoss
mindspore.nn.SmoothL1Loss


@ -58,6 +58,17 @@ functional operators are initialized Primitives that can be used directly as functions
mindspore.ops.hardshrink
mindspore.ops.tanh
Loss Functions
^^^^^^^^^^^^^^
.. mscnplatformautosummary::
:toctree: ops
:nosignatures:
:template: classtemplate.rst
mindspore.ops.cross_entropy
mindspore.ops.nll_loss
Mathematical Operators
----------------------


@ -0,0 +1,78 @@
mindspore.nn.CrossEntropyLoss
=============================
.. py:class:: mindspore.nn.CrossEntropyLoss(weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0)
Computes the cross entropy loss between the predictions and the targets.
The cross entropy loss supports two kinds of targets:
- Class indices (int) in the range :math:`[0, C)`, where :math:`C` is the number of classes. When reduction is 'none', the cross entropy loss is:
.. math::
\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
l_n = - w_{y_n} \log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^C \exp(x_{n,c})}
\cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\}
where :math:`x` is the prediction, :math:`y` is the target, :math:`w` is the weight, N is the batch size, and :math:`c`, ranging over [0, C-1], is the class index, with :math:`C` the number of classes.
If reduction is not 'none' (default 'mean'), then
.. math::
\ell(x, y) = \begin{cases}
\sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n} \cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\}} l_n, &
\text{if reduction} = \text{`mean';}\\
\sum_{n=1}^N l_n, &
\text{if reduction} = \text{`sum'.}
\end{cases}
- Class probabilities (float), used when the target is a probability distribution over the classes. When reduction is 'none', the cross entropy loss is:
.. math::
\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
l_n = - \sum_{c=1}^C w_c \log \frac{\exp(x_{n,c})}{\sum_{i=1}^C \exp(x_{n,i})} y_{n,c}
where :math:`x` is the prediction, :math:`y` is the target, :math:`w` is the weight, N is the batch size, and :math:`c`, ranging over [0, C-1], is the class index, with :math:`C` the number of classes.
If reduction is not 'none' (default 'mean'), then
.. math::
\ell(x, y) = \begin{cases}
\frac{\sum_{n=1}^N l_n}{N}, &
\text{if reduction} = \text{`mean';}\\
\sum_{n=1}^N l_n, &
\text{if reduction} = \text{`sum'.}
\end{cases}
**Parameters:**
- **weight** (Tensor) - The rescaling weight applied to each class. If not None, the shape is (C,). The data type only supports float32 or float16. Default: None.
- **ignore_index** (int) - Specifies a target value (typically a padding value) that is ignored and does not contribute to the gradient. Default: -100.
- **reduction** (string) - The reduction applied to the output: 'none', 'mean', or 'sum'. Default: 'mean'.
- **label_smoothing** (float) - Label smoothing value, a regularization technique to prevent overfitting when computing the loss. The value range is [0.0, 1.0]. Default: 0.0.
**Inputs:**
- **logits** (Tensor) - The predictions, of shape :math:`(N, C)`, :math:`(N, C, H, W)` for 2D data, or :math:`(N, C, d_1, d_2, ..., d_K)` for high-dimensional data. The values are unnormalized logits. The data type only supports float32 or float16.
- **labels** (Tensor) - The targets, of shape :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` for high-dimensional data. For class-index targets, the data type only supports int32.
**Outputs:**
Tensor, with the same data type as `logits`.
**Raises:**
- **TypeError** - `weight` is not a Tensor.
- **TypeError** - The dtype of `weight` is neither float16 nor float32.
- **TypeError** - `ignore_index` is not an int.
- **ValueError** - `reduction` is not one of "mean", "sum", or "none".
- **TypeError** - `label_smoothing` is not a float.
- **TypeError** - `logits` is not a Tensor.
- **TypeError** - `labels` is not a Tensor.
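A minimal usage sketch for the class-index case (the shapes, random data, and dtypes below are illustrative assumptions chosen to satisfy the float32/int32 constraints above):
>>> import numpy as np
>>> import mindspore
>>> import mindspore.nn as nn
>>> logits = mindspore.Tensor(np.random.randn(3, 5).astype(np.float32))
>>> labels = mindspore.Tensor(np.array([1, 0, 4]).astype(np.int32))
>>> loss = nn.CrossEntropyLoss()
>>> output = loss(logits, labels)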


@ -0,0 +1,51 @@
mindspore.nn.NLLLoss
====================
.. py:class:: mindspore.nn.NLLLoss(weight=None, ignore_index=-100, reduction='mean')
Computes the negative log likelihood loss between the predictions and the targets.
When reduction is 'none', the negative log likelihood loss is:
.. math::
\ell(x, t)=L=\left\{l_{1}, \ldots, l_{N}\right\}^{\top},
\quad l_{n}=-w_{t_{n}} x_{n, t_{n}},
\quad w_{c}=\text { weight }[c] \cdot \mathbb{1}
\{c \not= \text{ignore\_index}\},
where :math:`x` is the prediction, :math:`t` is the target, :math:`w` is the weight, N is the batch size, and :math:`c`, ranging over [0, C-1], is the class index, with :math:`C` the number of classes.
If reduction is not 'none' (default 'mean'), then
.. math::
\ell(x, t)=L=\left\{\begin{array}{ll}
\sum_{n=1}^{N} \frac{1}{\sum_{n=1}^{N} w_{t_{n}}} l_{n}, & \text { if reduction }=\text { 'mean'; } \\
\sum_{n=1}^{N} l_{n}, & \text { if reduction }=\text { 'sum' }
\end{array}\right.
**Parameters:**
- **weight** (Tensor) - The rescaling weight applied to each class. If not None, the shape is (C,). The data type only supports float32 or float16. Default: None.
- **ignore_index** (int) - Specifies a target value (typically a padding value) that is ignored and does not contribute to the gradient. Default: -100.
- **reduction** (string) - The reduction applied to the output: 'none', 'mean', or 'sum'. Default: 'mean'.
**Inputs:**
- **logits** (Tensor) - The predictions, of shape :math:`(N, C)`, :math:`(N, C, H, W)` for 2D data, or :math:`(N, C, d_1, d_2, ..., d_K)` for high-dimensional data. The values are expected to be log-probabilities. The data type only supports float32 or float16.
- **labels** (Tensor) - The targets, of shape :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` for high-dimensional data. The data type only supports int32.
**Outputs:**
Tensor, with the same data type as `logits`.
**Raises:**
- **TypeError** - `weight` is not a Tensor.
- **TypeError** - The dtype of `weight` is neither float16 nor float32.
- **TypeError** - `ignore_index` is not an int.
- **ValueError** - `reduction` is not one of "mean", "sum", or "none".
- **TypeError** - `logits` is not a Tensor.
- **TypeError** - `labels` is not a Tensor.
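A minimal usage sketch (the data and the explicit LogSoftmax step are illustrative assumptions; NLLLoss expects log-probabilities as stated above):
>>> import numpy as np
>>> import mindspore
>>> import mindspore.nn as nn
>>> import mindspore.ops as ops
>>> logits = mindspore.Tensor(np.random.randn(3, 5).astype(np.float32))
>>> log_probs = ops.LogSoftmax(axis=1)(logits)
>>> labels = mindspore.Tensor(np.array([1, 0, 4]).astype(np.int32))
>>> loss = nn.NLLLoss()
>>> output = loss(log_probs, labels)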


@ -0,0 +1,66 @@
mindspore.ops.cross_entropy
===========================
.. py:function:: mindspore.ops.cross_entropy(inputs, target, weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0)
Computes the cross entropy loss between the predictions and the targets.
The cross_entropy function supports two kinds of targets:
- Class indices (int) in the range :math:`[0, C)`, where :math:`C` is the number of classes. When reduction is 'none', the cross entropy loss is:
.. math::
\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
l_n = - w_{y_n} \log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^C \exp(x_{n,c})}
\cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\}
where :math:`x` is the prediction, :math:`y` is the target, :math:`w` is the weight, N is the batch size, and :math:`c`, ranging over [0, C-1], is the class index, with :math:`C` the number of classes.
If reduction is not 'none' (default 'mean'), then
.. math::
\ell(x, y) = \begin{cases}
\sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n} \cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\}} l_n, &
\text{if reduction} = \text{`mean';}\\
\sum_{n=1}^N l_n, &
\text{if reduction} = \text{`sum'.}
\end{cases}
- Class probabilities (float), used when the target is a probability distribution over the classes. When reduction is 'none', the cross entropy loss is:
.. math::
\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
l_n = - \sum_{c=1}^C w_c \log \frac{\exp(x_{n,c})}{\sum_{i=1}^C \exp(x_{n,i})} y_{n,c}
where :math:`x` is the prediction, :math:`y` is the target, :math:`w` is the weight, N is the batch size, and :math:`c`, ranging over [0, C-1], is the class index, with :math:`C` the number of classes.
If reduction is not 'none' (default 'mean'), then
.. math::
\ell(x, y) = \begin{cases}
\frac{\sum_{n=1}^N l_n}{N}, &
\text{if reduction} = \text{`mean';}\\
\sum_{n=1}^N l_n, &
\text{if reduction} = \text{`sum'.}
\end{cases}
**Parameters:**
- **inputs** (Tensor) - The predictions, of shape :math:`(N, C)`, :math:`(N, C, H, W)` for 2D data, or :math:`(N, C, d_1, d_2, ..., d_K)` for high-dimensional data. The values are unnormalized logits. The data type only supports float32 or float16.
- **target** (Tensor) - The targets, of shape :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` for high-dimensional data. For class-index targets, the data type only supports int32.
- **weight** (Tensor) - The rescaling weight applied to each class. If not None, the shape is (C,). The data type only supports float32 or float16. Default: None.
- **ignore_index** (int) - Specifies a target value (typically a padding value) that is ignored and does not contribute to the gradient. Default: -100.
- **reduction** (string) - The reduction applied to the output: 'none', 'mean', or 'sum'. Default: 'mean'.
- **label_smoothing** (float) - Label smoothing value, a regularization technique to prevent overfitting when computing the loss. The value range is [0.0, 1.0]. Default: 0.0.
**Returns:**
Tensor, with the same data type as `inputs`.
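A minimal usage sketch covering both kinds of targets (the data, shapes, and the row normalization of the probability targets are illustrative assumptions):
>>> import numpy as np
>>> import mindspore
>>> from mindspore import ops
>>> inputs = mindspore.Tensor(np.random.randn(3, 5).astype(np.float32))
>>> # class-index targets
>>> target = mindspore.Tensor(np.array([1, 0, 4]).astype(np.int32))
>>> out_index = ops.cross_entropy(inputs, target)
>>> # class-probability targets: same shape as inputs, rows summing to 1
>>> probs = np.random.rand(3, 5).astype(np.float32)
>>> out_prob = ops.cross_entropy(inputs, mindspore.Tensor(probs / probs.sum(axis=1, keepdims=True)))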


@ -0,0 +1,41 @@
mindspore.ops.nll_loss
======================
.. py:function:: mindspore.ops.nll_loss(inputs, target, weight=None, ignore_index=None, reduction='mean', label_smoothing=0.0)
Computes the negative log likelihood loss between the predictions and the targets.
When reduction is 'none', the negative log likelihood loss is:
.. math::
\ell(x, t)=L=\left\{l_{1}, \ldots, l_{N}\right\}^{\top},
\quad l_{n}=-w_{t_{n}} x_{n, t_{n}},
\quad w_{c}=\text { weight }[c] \cdot \mathbb{1}
\{c \not= \text{ignore\_index}\},
where :math:`x` is the prediction, :math:`t` is the target, :math:`w` is the weight, N is the batch size, and :math:`c`, ranging over [0, C-1], is the class index, with :math:`C` the number of classes.
If reduction is not 'none' (default 'mean'), then
.. math::
\ell(x, t)=L=\left\{\begin{array}{ll}
\sum_{n=1}^{N} \frac{1}{\sum_{n=1}^{N} w_{t_{n}}} l_{n}, & \text { if reduction }=\text { 'mean'; } \\
\sum_{n=1}^{N} l_{n}, & \text { if reduction }=\text { 'sum' }
\end{array}\right.
**Parameters:**
- **inputs** (Tensor) - The predictions, of shape :math:`(N, C)`, :math:`(N, C, H, W)` for 2D data, or :math:`(N, C, d_1, d_2, ..., d_K)` for high-dimensional data. The values are expected to be log-probabilities. The data type only supports float32 or float16.
- **target** (Tensor) - The targets, of shape :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` for high-dimensional data. The data type only supports int32.
- **weight** (Tensor) - The rescaling weight applied to each class. If not None, the shape is (C,). The data type only supports float32 or float16. Default: None.
- **ignore_index** (int) - Specifies a target value (typically a padding value) that is ignored and does not contribute to the gradient. Default: None.
- **reduction** (string) - The reduction applied to the output: 'none', 'mean', or 'sum'. Default: 'mean'.
- **label_smoothing** (float) - Label smoothing value, a regularization technique to prevent overfitting when computing the loss. The value range is [0.0, 1.0]. Default: 0.0.
**Returns:**
Tensor, with the same data type as `inputs`.
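A minimal usage sketch with per-class weights (the data and weights are illustrative assumptions; the inputs are made log-probabilities with an explicit LogSoftmax):
>>> import numpy as np
>>> import mindspore
>>> from mindspore import ops
>>> log_probs = ops.LogSoftmax(axis=1)(mindspore.Tensor(np.random.randn(3, 5).astype(np.float32)))
>>> target = mindspore.Tensor(np.array([1, 0, 4]).astype(np.int32))
>>> weight = mindspore.Tensor(np.array([0.1, 0.2, 0.3, 0.4, 0.5]).astype(np.float32))
>>> output = ops.nll_loss(log_probs, target, weight=weight)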


@ -207,12 +207,14 @@ Loss Function
mindspore.nn.BCELoss
mindspore.nn.BCEWithLogitsLoss
mindspore.nn.CosineEmbeddingLoss
mindspore.nn.CrossEntropyLoss
mindspore.nn.DiceLoss
mindspore.nn.FocalLoss
mindspore.nn.HuberLoss
mindspore.nn.L1Loss
mindspore.nn.MSELoss
mindspore.nn.MultiClassDiceLoss
mindspore.nn.NLLLoss
mindspore.nn.RMSELoss
mindspore.nn.SampledSoftmaxLoss
mindspore.nn.SmoothL1Loss


@ -58,6 +58,17 @@ Activation Functions
mindspore.ops.hardshrink
mindspore.ops.tanh
Loss Functions
^^^^^^^^^^^^^^
.. msplatformautosummary::
:toctree: ops
:nosignatures:
:template: classtemplate.rst
mindspore.ops.cross_entropy
mindspore.ops.nll_loss
Mathematical Operators
----------------------


@ -22,10 +22,10 @@ It shows how well the model works on a dataset and the optimization target which
from .loss import LossBase, L1Loss, MSELoss, SmoothL1Loss, SoftMarginLoss, FocalLoss,\
SoftmaxCrossEntropyWithLogits, BCELoss, CosineEmbeddingLoss, \
SampledSoftmaxLoss, DiceLoss, BCEWithLogitsLoss, MultiClassDiceLoss,\
RMSELoss, MAELoss, HuberLoss
RMSELoss, MAELoss, HuberLoss, CrossEntropyLoss, NLLLoss
__all__ = ['LossBase', 'L1Loss', 'MSELoss', 'SmoothL1Loss', 'SoftMarginLoss', 'FocalLoss',
'SoftmaxCrossEntropyWithLogits', 'BCELoss', 'BCEWithLogitsLoss',
'CosineEmbeddingLoss', 'SampledSoftmaxLoss', 'DiceLoss', 'MultiClassDiceLoss',
'RMSELoss', 'MAELoss', 'HuberLoss']
'RMSELoss', 'MAELoss', 'HuberLoss', 'CrossEntropyLoss', 'NLLLoss']


@ -1590,3 +1590,192 @@ class HuberLoss(LossBase):
loss = self.select(condition, l1, l2)
return self.get_loss(loss)
class NLLLoss(LossBase):
r"""
Gets the negative log likelihood loss between logits and labels.
The nll loss with reduction=none can be described as:
.. math::
\ell(x, t)=L=\left\{l_{1}, \ldots, l_{N}\right\}^{\top},
\quad l_{n}=-w_{t_{n}} x_{n, t_{n}},
\quad w_{c}=\text { weight }[c] \cdot \mathbb{1}
\{c \not= \text{ignore\_index}\},
where :math:`x` is the logits, :math:`t` is the labels, :math:`w` is the weight,
N is the batch size, :math:`c` belonging to [0, C-1] is class index, where :math:`C` is the number of classes.
If reduction is not 'none' (default 'mean'), then
.. math::
\ell(x, t)=\left\{\begin{array}{ll}
\sum_{n=1}^{N} \frac{1}{\sum_{n=1}^{N} w_{t_{n}}} l_{n}, & \text { if reduction }=\text { 'mean'; } \\
\sum_{n=1}^{N} l_{n}, & \text { if reduction }=\text { 'sum' }
\end{array}\right.
Args:
weight (Tensor, optional): The rescaling weight to each class. If the value is not None, the shape is (C,).
The data type only supports float32 or float16. Default: None.
ignore_index (int, optional): Specifies a target value that is ignored (typically for padding value)
and does not contribute to the gradient. Default: -100.
reduction (string, optional): Apply specific reduction method to the output: 'none', 'mean', or 'sum'.
Default: 'mean'.
Inputs:
- **logits** (Tensor) - Tensor of shape :math:`(N, C)` where `C = number of classes` or :math:`(N, C, H, W)`
in case of 2D Loss, or :math:`(N, C, d_1, d_2, ..., d_K)`(for high-dimensional data).
Data type must be float16 or float32. `inputs` is expected to be log-probabilities.
- **labels** (Tensor) - :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` for high-dimensional data.
Data type must be int32.
Outputs:
Tensor, the computed negative log likelihood loss value.
Raises:
TypeError: If `weight` is not a Tensor.
TypeError: If `ignore_index` is not an int.
TypeError: If the data type of `weight` is not float16 or float32.
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
TypeError: If `logits` is not a Tensor.
TypeError: If `labels` is not a Tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> logits = mindspore.Tensor(np.random.randn(3, 5).astype(np.float32))
>>> labels = mindspore.Tensor(np.array([1, 0, 4]).astype(np.int32))
>>> loss = nn.NLLLoss()
>>> output = loss(logits, labels)
"""
def __init__(self, weight=None, ignore_index=-100, reduction='mean'):
super().__init__(reduction)
validator.check_value_type('ignore_index', ignore_index, int, self.cls_name)
if weight is not None:
validator.check_value_type("weight", weight, [Tensor], self.cls_name)
validator.check_type_name('weight', weight.dtype, [mstype.float16, mstype.float32], self.cls_name)
self.weight = weight
self.ignore_index = ignore_index
self.reduction = reduction
def construct(self, logits, labels):
_check_is_tensor('logits', logits, self.cls_name)
_check_is_tensor('labels', labels, self.cls_name)
return F.nll_loss(logits, labels, self.weight, self.ignore_index, self.reduction)
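A quick usage sketch, not part of this file, showing how ignore_index interacts with NLLLoss (the choice of 0 as the ignored label is an illustrative assumption; samples whose label equals ignore_index contribute neither to the loss nor to the 'mean' denominator, per the masking in _nll_loss):
>>> # assumes: import numpy as np; from mindspore import Tensor, nn, ops
>>> log_probs = ops.LogSoftmax(axis=1)(Tensor(np.random.randn(3, 5).astype(np.float32)))
>>> labels = Tensor(np.array([1, 0, 4]).astype(np.int32))
>>> loss = nn.NLLLoss(ignore_index=0)
>>> output = loss(log_probs, labels)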
class CrossEntropyLoss(LossBase):
r"""
The cross entropy loss between input and target.
The cross entropy loss supports two kinds of targets:
- Class indices (int) in the range :math:`[0, C)` where :math:`C` is the number of classes,
the loss with reduction=none can be described as:
.. math::
\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
l_n = - w_{y_n} \log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^C \exp(x_{n,c})}
\cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\}
where :math:`x` is the inputs, :math:`y` is the target, :math:`w` is the weight,
N is the batch size, :math:`c` belonging to [0, C-1] is class index, where :math:`C` is the number of classes.
If reduction is not 'none' (default 'mean'), then
.. math::
\ell(x, y) = \begin{cases}
\sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n} \cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\}} l_n, &
\text{if reduction} = \text{`mean';}\\
\sum_{n=1}^N l_n, &
\text{if reduction} = \text{`sum'.}
\end{cases}
- Probabilities (float) for each class, useful when a probability distribution over the classes is
required for each minibatch item; the loss with reduction=none can be described as:
.. math::
\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
l_n = - \sum_{c=1}^C w_c \log \frac{\exp(x_{n,c})}{\sum_{i=1}^C \exp(x_{n,i})} y_{n,c}
where :math:`x` is the inputs, :math:`y` is the target, :math:`w` is the weight,
N is the batch size, :math:`c` belonging to [0, C-1] is class index, where :math:`C` is the number of classes.
If reduction is not 'none' (default 'mean'), then
.. math::
\ell(x, y) = \begin{cases}
\frac{\sum_{n=1}^N l_n}{N}, &
\text{if reduction} = \text{`mean';}\\
\sum_{n=1}^N l_n, &
\text{if reduction} = \text{`sum'.}
\end{cases}
Args:
weight (Tensor): The rescaling weight to each class. If the value is not None, the shape is (C,).
The data type only supports float32 or float16. Default: None.
ignore_index (int): Specifies a target value that is ignored (typically for padding value)
and does not contribute to the gradient. Default: -100.
reduction (string): Apply specific reduction method to the output: 'none', 'mean', or 'sum'.
Default: 'mean'.
label_smoothing (float): Label smoothing values, a regularization tool used to prevent the model
from overfitting when calculating Loss. The value range is [0.0, 1.0]. Default value: 0.0.
Inputs:
- **logits** (Tensor) - Tensor of shape :math:`(N, C)` where `C = number of classes` or :math:`(N, C, H, W)`
in case of 2D Loss, or :math:`(N, C, d_1, d_2, ..., d_K)`. Data type must be float16 or float32.
- **labels** (Tensor) - :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` for high-dimensional data.
Outputs:
Tensor, the computed cross entropy loss value.
Raises:
TypeError: If `weight` is not a Tensor.
TypeError: If `ignore_index` is not an int.
TypeError: If the data type of `weight` is not float16 or float32.
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
TypeError: If `label_smoothing` is not a float.
TypeError: If `logits` is not a Tensor.
TypeError: If `labels` is not a Tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> inputs = mindspore.Tensor(np.random.randn(3, 5).astype(np.float32))
>>> target = mindspore.Tensor(np.array([1, 0, 4]).astype(np.int32))
>>> loss = nn.CrossEntropyLoss()
>>> output = loss(inputs, target)
"""
def __init__(self, weight=None, ignore_index=-100, reduction='mean',
label_smoothing=0.0):
super().__init__(reduction)
validator.check_value_type('ignore_index', ignore_index, int, self.cls_name)
validator.check_value_type('label_smoothing', label_smoothing, float, self.cls_name)
validator.check_float_range(label_smoothing, 0.0, 1.0, Rel.INC_BOTH, 'label_smoothing', self.cls_name)
if weight is not None:
validator.check_value_type("weight", weight, [Tensor], self.cls_name)
validator.check_type_name('weight', weight.dtype, [mstype.float16, mstype.float32], self.cls_name)
self.weight = weight
self.ignore_index = ignore_index
self.reduction = reduction
self.label_smoothing = label_smoothing
def construct(self, logits, labels):
_check_is_tensor('logits', logits, self.cls_name)
_check_is_tensor('labels', labels, self.cls_name)
return F.cross_entropy(logits, labels, self.weight, self.ignore_index, self.reduction, self.label_smoothing)
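A quick usage sketch, not part of this file, for the soft-target path of CrossEntropyLoss (the probability targets and the label_smoothing value are illustrative assumptions; when labels has as many elements as logits, construct dispatches to the probability branch of F.cross_entropy):
>>> # assumes: import numpy as np; from mindspore import Tensor, nn
>>> logits = Tensor(np.random.randn(3, 5).astype(np.float32))
>>> probs = np.random.rand(3, 5).astype(np.float32)
>>> soft_labels = Tensor(probs / probs.sum(axis=1, keepdims=True))
>>> loss = nn.CrossEntropyLoss(label_smoothing=0.1)
>>> output = loss(logits, soft_labels)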


@ -192,6 +192,8 @@ from .nn_func import (
hardshrink,
softsign,
pdist,
nll_loss,
cross_entropy,
)
from .linalg_func import (
svd,


@ -248,11 +248,226 @@ def pdist(x, p=2.0):
return pdist_(x)
def cross_entropy(inputs, target, weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0):
r"""
The cross entropy loss between input and target.
The cross entropy loss supports two kinds of targets:
- Class indices (int) in the range :math:`[0, C)` where :math:`C` is the number of classes,
the loss with reduction=none can be described as:
.. math::
\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
l_n = - w_{y_n} \log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^C \exp(x_{n,c})}
\cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\}
where :math:`x` is the inputs, :math:`y` is the target, :math:`w` is the weight,
N is the batch size, :math:`c` belonging to [0, C-1] is class index, where :math:`C` is the number of classes.
If reduction is not 'none' (default 'mean'), then
.. math::
\ell(x, y) = \begin{cases}
\sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n} \cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\}} l_n, &
\text{if reduction} = \text{`mean';}\\
\sum_{n=1}^N l_n, &
\text{if reduction} = \text{`sum'.}
\end{cases}
- Probabilities (float) for each class, useful when a probability distribution over the classes is
required for each minibatch item; the loss with reduction=none can be described as:
.. math::
\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
l_n = - \sum_{c=1}^C w_c \log \frac{\exp(x_{n,c})}{\sum_{i=1}^C \exp(x_{n,i})} y_{n,c}
where :math:`x` is the inputs, :math:`y` is the target, :math:`w` is the weight,
N is the batch size, :math:`c` belonging to [0, C-1] is class index, where :math:`C` is the number of classes.
If reduction is not 'none' (default 'mean'), then
.. math::
\ell(x, y) = \begin{cases}
\frac{\sum_{n=1}^N l_n}{N}, &
\text{if reduction} = \text{`mean';}\\
\sum_{n=1}^N l_n, &
\text{if reduction} = \text{`sum'.}
\end{cases}
Args:
inputs (Tensor): :math:`(N, C)` where `C = number of classes` or :math:`(N, C, H, W)`
in case of 2D Loss, or :math:`(N, C, d_1, d_2, ..., d_K)`.
`inputs` is expected to contain unnormalized logits (log-softmax is applied internally), data type must be float16 or float32.
target (Tensor): :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` for
high-dimensional loss.
weight (Tensor): A rescaling weight applied to the loss of each batch element.
If not None, the shape is :math:`(C,)`,
data type must be float16 or float32. Default: None.
ignore_index (int): Specifies a target value that is ignored
and does not contribute to the input gradient. Default: -100
reduction (string): Apply specific reduction method to the output: 'none', 'mean', or 'sum'.
Default: 'mean'.
label_smoothing (float): Label smoothing values, a regularization tool used to prevent the model
from overfitting when calculating Loss. The value range is [0.0, 1.0]. Default value: 0.0.
Returns:
Tensor, the computed loss value.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> inputs = mindspore.Tensor(np.random.randn(3, 5).astype(np.float32))
>>> target = mindspore.Tensor(np.array([1, 0, 4]).astype(np.int32))
>>> output = ops.cross_entropy(inputs, target)
"""
class_dim = 0 if inputs.ndim == 1 else 1
# Soft-label case: the target has as many elements as the inputs, so it is
# treated as per-class probabilities.
if inputs.size == target.size:
return _cross_entropy(inputs, target, class_dim, weight, reduction, label_smoothing)
# Hard-label case: log-softmax the logits along the class dimension, then take
# the negative log likelihood of the target class indices.
return nll_loss(P.LogSoftmax(class_dim)(inputs), target, weight, ignore_index, reduction, label_smoothing)
def _cross_entropy(inputs, target, target_dim, weight=None, reduction='mean', label_smoothing=0.0):
"""cross entropy inner function"""
class_dim = 0 if inputs.ndim == 1 else 1
n_classes = inputs.shape[class_dim]
inputs = P.LogSoftmax(target_dim)(inputs)
if label_smoothing > 0.0:
# Blend the soft targets with a uniform distribution over the classes.
target = target * (1 - label_smoothing) + label_smoothing / n_classes
if weight is None:
weight = P.OnesLike()(inputs)
if reduction == 'mean':
# inputs.size / n_classes is the number of batch elements.
return -(inputs * target * weight).sum() / (inputs.size / n_classes)
if reduction == 'sum':
return -(inputs * target * weight).sum()
return -(inputs * target * weight).sum(class_dim)
def nll_loss(inputs, target, weight=None, ignore_index=None, reduction='mean', label_smoothing=0.0):
r"""
Gets the negative log likelihood loss between inputs and target.
The nll loss with reduction=none can be described as:
.. math::
\ell(x, t)=L=\left\{l_{1}, \ldots, l_{N}\right\}^{\top},
\quad l_{n}=-w_{t_{n}} x_{n, t_{n}},
\quad w_{c}=\text { weight }[c] \cdot \mathbb{1}
\{c \not= \text{ignore\_index}\},
where :math:`x` is the inputs, :math:`t` is the target, :math:`w` is the weight,
N is the batch size, :math:`c` belonging to [0, C-1] is class index, where :math:`C` is the number of classes.
If reduction is not 'none' (default 'mean'), then
.. math::
\ell(x, t)=\left\{\begin{array}{ll}
\sum_{n=1}^{N} \frac{1}{\sum_{n=1}^{N} w_{t_{n}}} l_{n}, & \text { if reduction }=\text { 'mean'; } \\
\sum_{n=1}^{N} l_{n}, & \text { if reduction }=\text { 'sum' }
\end{array}\right.
Args:
inputs (Tensor): :math:`(N, C)` where `C = number of classes` or :math:`(N, C, H, W)`
in case of 2D Loss, or :math:`(N, C, d_1, d_2, ..., d_K)`.
`inputs` is expected to be log-probabilities, data type must be float16 or float32.
target (Tensor): :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` for
high-dimensional loss, data type must be int32.
weight (Tensor): A rescaling weight applied to the loss of each batch element.
If not None, the shape is :math:`(C,)`.
The data type must be float16 or float32. Default: None.
ignore_index (int): Specifies a target value that is ignored
and does not contribute to the input gradient. Default: None (no target value is ignored).
reduction (string): Apply specific reduction method to the output: 'none', 'mean', or 'sum'.
Default: 'mean'.
label_smoothing (float): Label smoothing values, a regularization tool used to prevent the model
from overfitting when calculating Loss. The value range is [0.0, 1.0]. Default value: 0.0.
Returns:
Tensor, the computed loss value.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> inputs = mindspore.Tensor(np.random.randn(3, 5).astype(np.float32))
>>> target = mindspore.Tensor(np.array([1, 0, 4]).astype(np.int32))
>>> output = ops.nll_loss(inputs, target)
"""
ndim = inputs.ndim
if ndim == 2:
# (N, C): the class dimension is the last one.
ret = _nll_loss(inputs, target, -1, weight, ignore_index, reduction, label_smoothing)
elif ndim == 4:
# (N, C, H, W): the class dimension is dim 1.
ret = _nll_loss(inputs, target, 1, weight, ignore_index, reduction, label_smoothing)
else:
# (N, C, d_1, ..., d_K): flatten the trailing dimensions, compute the loss with
# the class dimension at dim 1, then restore the original shape when no
# reduction is applied.
n = inputs.shape[0]
c = inputs.shape[1]
out_size = (n,) + inputs.shape[2:]
inputs = inputs.view(n, c, 1, -1)
target = target.view(n, 1, -1)
if reduction != 'none':
ret = _nll_loss(inputs, target, 1, weight, ignore_index, reduction, label_smoothing)
else:
ret = _nll_loss(inputs, target, 1, weight, ignore_index, label_smoothing=label_smoothing)
ret = ret.view(out_size)
return ret
def _nll_loss(inputs, target, target_dim=-1, weight=None, ignore_index=None, reduction='none', label_smoothing=0.0):
"""nll loss inner function"""
if target.ndim == inputs.ndim - 1:
target = target.expand_dims(target_dim)
# Gather the log-probability of the target class for each element.
loss = P.Neg()(P.GatherD()(inputs, target_dim, target))
# Sum over all classes; this is the uniform-distribution term used by label smoothing.
smooth_loss = P.Neg()(inputs.sum(axis=target_dim, keepdims=True))
if weight is not None:
loss_weights = P.Gather()(weight, target, 0)
loss = loss * loss_weights
else:
loss_weights = P.OnesLike()(loss)
if ignore_index is not None:
# Zero out the loss and the weights of elements whose target equals ignore_index.
non_pad_mask = P.Equal()(target, ignore_index)
loss = loss.masked_fill(non_pad_mask, 0.)
loss_weights = loss_weights.masked_fill(non_pad_mask, 0.)
smooth_loss = smooth_loss.masked_fill(non_pad_mask, 0.)
if reduction == 'sum':
loss = loss.sum()
smooth_loss = smooth_loss.sum()
elif reduction == 'mean':
# Weighted mean: the denominator is the sum of the per-element weights.
loss = loss.sum() / loss_weights.sum()
smooth_loss = smooth_loss.mean()
else:
# 'none': collapse only the gathered class dimension, keeping one value per element.
loss = loss.sum(target_dim)
smooth_loss = smooth_loss.sum(target_dim)
# Blend the hard-target loss with the label smoothing term.
eps_i = label_smoothing / inputs.shape[target_dim]
loss = (1. - label_smoothing) * loss + eps_i * smooth_loss
return loss
__all__ = [
'deformable_conv2d',
'fast_gelu',
'hardshrink',
'softsign',
'pdist',
'cross_entropy',
'nll_loss'
]
__all__.sort()
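As a sanity check on the formulas above, here is a hedged numpy cross-check, not part of the commit (it assumes an eager/PyNative setup and uses class-index targets with the default reduction, no weight, and no ignore_index; the two values should agree up to float32 rounding):
>>> import numpy as np
>>> import mindspore
>>> from mindspore import ops
>>> x = np.random.randn(3, 5).astype(np.float32)
>>> t = np.array([1, 0, 4]).astype(np.int32)
>>> log_probs = x - np.log(np.exp(x).sum(axis=1, keepdims=True))  # log-softmax by hand
>>> manual = -log_probs[np.arange(3), t].mean()
>>> out = ops.cross_entropy(mindspore.Tensor(x), mindspore.Tensor(t))
>>> np.allclose(out.asnumpy(), manual, atol=1e-5)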


@ -230,3 +230,65 @@ def test_huber_loss():
input_data = Tensor(np.array([[1, 2, 3], [2, 3, 4]]).astype(np.float32))
target_data = Tensor(np.array([[0, 2, 5], [3, 1, 1]]).astype(np.float32))
loss(input_data, target_data)
def test_cross_entropy_loss():
"""
Feature: Test CrossEntropyLoss.
Description: Test CrossEntropyLoss functional.
Expectation: Success.
"""
loss = nn.CrossEntropyLoss()
input_data = Tensor(np.random.randn(3, 5).astype(np.float32))
target_data = Tensor(np.array([1, 0, 4]).astype(np.int32))
loss(input_data, target_data)
def test_cross_entropy_loss_with_weight():
"""
Feature: Test CrossEntropyLoss.
Description: Test CrossEntropyLoss functional.
Expectation: Success.
"""
input_data = Tensor(np.random.randn(3, 5).astype(np.float32))
target_data = Tensor(np.array([1, 0, 4]).astype(np.int32))
weight_data = Tensor(np.array([0.1, 0.2, 0.3, 0.4, 0.5]).astype(np.float32))
loss = nn.CrossEntropyLoss(weight=weight_data)
loss(input_data, target_data)
def test_nll_loss():
"""
Feature: Test NLLLoss.
Description: Test NLLLoss functional.
Expectation: Success.
"""
loss = nn.NLLLoss()
input_data = Tensor(np.random.randn(3, 5).astype(np.float32))
target_data = Tensor(np.array([1, 0, 4]).astype(np.int32))
loss(input_data, target_data)
def test_nll_loss_with_weight():
"""
Feature: Test NLLLoss.
Description: Test NLLLoss functional.
Expectation: Success.
"""
input_data = Tensor(np.random.randn(3, 5).astype(np.float32))
target_data = Tensor(np.array([1, 0, 4]).astype(np.int32))
weight_data = Tensor(np.array([0.1, 0.2, 0.3, 0.4, 0.5]).astype(np.float32))
loss = nn.NLLLoss(weight=weight_data)
loss(input_data, target_data)
def test_nll_loss_4d():
"""
Feature: Test NLLLoss.
Description: Test NLLLoss functional.
Expectation: Success.
"""
loss = nn.NLLLoss()
input_data = Tensor(np.random.randn(3, 5, 1, 1).astype(np.float32))
target_data = Tensor(np.array([[[1]], [[0]], [[4]]]).astype(np.int32))
loss(input_data, target_data)
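A possible additional test, sketched here rather than taken from the commit, exercising the class-probability branch of CrossEntropyLoss together with label smoothing; it follows the conventions of the tests above:
def test_cross_entropy_loss_with_probability_target():
    """
    Feature: Test CrossEntropyLoss.
    Description: Test CrossEntropyLoss with class-probability targets and label smoothing.
    Expectation: Success.
    """
    loss = nn.CrossEntropyLoss(label_smoothing=0.1)
    input_data = Tensor(np.random.randn(3, 5).astype(np.float32))
    target_np = np.random.rand(3, 5).astype(np.float32)
    target_data = Tensor(target_np / target_np.sum(axis=1, keepdims=True))
    loss(input_data, target_data)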