add function.gaussian_nll_loss

shaojunsong 2022-11-08 09:36:30 +08:00
parent 9b33f045b6
commit e63c57ef62
8 changed files with 250 additions and 14 deletions


@ -62,6 +62,7 @@ mindspore.ops.function
mindspore.ops.binary_cross_entropy
mindspore.ops.binary_cross_entropy_with_logits
mindspore.ops.cross_entropy
mindspore.ops.gaussian_nll_loss
mindspore.ops.hinge_embedding_loss
mindspore.ops.mse_loss
mindspore.ops.nll_loss


@ -0,0 +1,41 @@
mindspore.ops.gaussian_nll_loss
================================
.. py:function:: mindspore.ops.gaussian_nll_loss(x, target, var, full=False, eps=1e-6, reduction='mean')

    Gaussian negative log likelihood loss.

    The targets are treated as samples from a Gaussian distribution whose expectation and variance are predicted by
    the neural network. For a Tensor `x` of expectations, a Tensor `target` of sampled values, and a Tensor `var` of
    positive variances, the computed loss is:

    .. math::
        \text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var},
        \ \text{eps}\right)\right) + \frac{\left(\text{x} - \text{target}\right)^2}
        {\text{max}\left(\text{var}, \ \text{eps}\right)}\right) + \text{const.}

    where :math:`eps` is used for stability of :math:`log`. By default, the constant term is omitted unless
    :math:`full=True`. If the shape of `var` is not the same as that of `x` (due to a homoscedastic assumption),
    it must either have a final dimension of 1 or have one fewer dimension (with all other sizes being the same)
    for correct broadcasting.

    Parameters:
        - **x** (Tensor) - Tensor of shape :math:`(N, *)` or :math:`(*)`, where :math:`*` means any number of additional dimensions.
        - **target** (Tensor) - Tensor of shape :math:`(N, *)` or :math:`(*)`, same shape as `x`, or same shape as `x` but with one dimension equal to 1 (to allow broadcasting).
        - **var** (Tensor) - Tensor of shape :math:`(N, *)` or :math:`(*)`, same shape as `x`, or same shape as `x` but with one dimension equal to 1 or with one fewer dimension (to allow broadcasting).
        - **full** (bool, optional) - Whether to include the constant term in the loss. If True, the constant is :math:`const = 0.5*\log(2\pi)`. Default: False.
        - **eps** (float, optional) - Used to improve the stability of :math:`log`; must be greater than 0. Default: 1e-6.
        - **reduction** (str, optional) - The reduction applied to the output, one of 'none', 'mean' or 'sum'. Default: 'mean'.

    Returns:
        Tensor or Tensor scalar, the computed loss depending on `reduction`.

    Raises:
        - **TypeError** - `x` is not a Tensor.
        - **TypeError** - `target` is not a Tensor.
        - **TypeError** - `var` is not a Tensor.
        - **TypeError** - `full` is not a bool.
        - **TypeError** - `eps` is not a float.
        - **ValueError** - `eps` is not a float within (0, inf).
        - **ValueError** - `reduction` is not one of 'none', 'mean' or 'sum'.

    References:
        Nix, D. A. and Weigend, A. S., "Estimating the mean and variance of the
        target probability distribution", Proceedings of 1994 IEEE International
        Conference on Neural Networks (ICNN'94), Orlando, FL, USA, 1994, pp. 55-60
        vol.1, doi: 10.1109/ICNN.1994.374138.
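As a quick numerical check of the formula documented above, a minimal NumPy sketch (reusing the sample data that appears in the examples and tests elsewhere in this commit) reproduces the documented values:

import numpy as np

# Same sample data as the examples: x holds the predicted expectations,
# target the observed samples, var the predicted (positive) variances.
x = np.arange(8).reshape((4, 2)).astype(np.float32)
target = np.array([2, 3, 1, 4, 6, 4, 4, 9], dtype=np.float32).reshape((4, 2))
var = np.ones((4, 1), dtype=np.float32)
eps = 1e-6

maxima = np.maximum(var, eps)
loss = 0.5 * (np.log(maxima) + (x - target) ** 2 / maxima)  # element-wise loss, constant term omitted
print(loss.mean())                              # 1.4375 (reduction='mean', full=False)
print(loss.mean() + 0.5 * np.log(2 * np.pi))    # ~2.35644, i.e. what full=True adds: 0.5*log(2*pi)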


@ -63,6 +63,7 @@ Loss Functions
mindspore.ops.binary_cross_entropy
mindspore.ops.binary_cross_entropy_with_logits
mindspore.ops.cross_entropy
mindspore.ops.gaussian_nll_loss
mindspore.ops.hinge_embedding_loss
mindspore.ops.mse_loss
mindspore.ops.nll_loss


@ -15,7 +15,6 @@
"""loss"""
from __future__ import absolute_import, division
import math
import mindspore
import mindspore.common.dtype as mstype
import mindspore.ops as ops
@ -2390,7 +2389,7 @@ class GaussianNLLLoss(LossBase):
        >>> var = Tensor(np.ones((4, 1)), mstype.float32)
        >>> output = loss(logits, labels, var)
        >>> print(output)
        Tensor(shape=[], dtype=Float32, value= 1.4375)
        1.4374993

    Reference:
        Nix, D. A. and Weigend, A. S., "Estimating the mean and variance of the
@ -2400,25 +2399,19 @@ class GaussianNLLLoss(LossBase):
"""
def __init__(self, *, full=False, eps=1e-6, reduction='mean'):
super(GaussianNLLLoss, self).__init__(reduction)
super(GaussianNLLLoss, self).__init__()
validator.check_float_range(eps, 0, float('inf'), Rel.INC_NEITHER, "eps", self.cls_name)
validator.check_value_type('full', full, [bool], self.cls_name)
validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', 'gaussian_nll_loss')
self.full = full
self.eps = eps
self.max = P.Maximum()
self.log = P.Log()
self.square = P.Square()
self.reduction = reduction
def construct(self, logits, labels, var):
_check_is_tensor('logits', logits, self.cls_name)
_check_is_tensor('labels', labels, self.cls_name)
_check_is_tensor('var', var, self.cls_name)
maxima = self.max(var, self.eps)
logarithm = self.log(maxima)
squared_loss = self.square(logits - labels)
c = 0 if not self.full else 0.5 * math.log(2 * math.pi)
loss = 0.5 * (logarithm + squared_loss / maxima) + c
return self.get_loss(loss)
return ops.gaussian_nll_loss(logits, labels, var, self.full, self.eps, self.reduction)
class HingeEmbeddingLoss(LossBase):
@ -2480,7 +2473,7 @@ class HingeEmbeddingLoss(LossBase):
        >>> loss = nn.HingeEmbeddingLoss(reduction='mean')
        >>> output = loss(logits, labels)
        >>> print(output)
        Tensor(shape=[], dtype=Float32, value= 1.6666667)
        0.16666667
    """

    def __init__(self, margin=1.0, reduction='mean'):
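Because `construct` in the updated GaussianNLLLoss class simply forwards to `ops.gaussian_nll_loss`, the cell and the functional API should return the same value; a minimal sketch, reusing the inputs from the docstring example above:

import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor

logits = Tensor(np.arange(8).reshape((4, 2)), mstype.float32)
labels = Tensor(np.array([2, 3, 1, 4, 6, 4, 4, 9]).reshape((4, 2)), mstype.float32)
var = Tensor(np.ones((4, 1)), mstype.float32)

# The cell delegates to the functional form, so both calls return the same scalar.
cell_loss = nn.GaussianNLLLoss(full=False, eps=1e-6, reduction='mean')(logits, labels, var)
func_loss = ops.gaussian_nll_loss(logits, labels, var, full=False, eps=1e-6, reduction='mean')
print(cell_loss, func_loss)  # 1.4374993 1.4374993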


@ -372,6 +372,7 @@ from .nn_func import (
elu,
gelu,
hinge_embedding_loss,
gaussian_nll_loss,
lp_pool1d,
lp_pool2d,
mse_loss,


@ -15,7 +15,7 @@
"""Defines nn operators with functional form."""
from __future__ import absolute_import
from math import pi
from math import pi, log
import mindspore.ops as ops
from mindspore.ops.primitive import constexpr
@ -3245,6 +3245,100 @@ def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reducti
    return (loss, log_alpha)


@constexpr
def _check_gaussian_nll_loss(full, eps, reduction):
    validator.check_value_type('full', full, [bool], 'gaussian_nll_loss')
    validator.check_positive_float(eps, 'eps', 'gaussian_nll_loss')
    validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', 'gaussian_nll_loss')


def gaussian_nll_loss(x, target, var, full=False, eps=1e-6, reduction='mean'):
    r"""
    Gaussian negative log likelihood loss.

    The targets are treated as samples from Gaussian distributions with expectations and variances predicted by the
    neural network. For a `target` tensor modelled as having a Gaussian distribution with a tensor of expectations
    `x` and a tensor of positive variances `var`, the loss is:

    .. math::
        \text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var},
        \ \text{eps}\right)\right) + \frac{\left(\text{x} - \text{target}\right)^2}
        {\text{max}\left(\text{var}, \ \text{eps}\right)}\right) + \text{const.}

    where :math:`eps` is used for stability of :math:`log`. By default, the constant term of the loss function is
    omitted unless :math:`full=True`. If the shape of :math:`var` is not the same as that of `x` (due to a
    homoscedastic assumption), it must either have a final dimension of 1 or have one fewer dimension
    (with all other sizes being the same) for correct broadcasting.

    Args:
        x (Tensor): Tensor of shape :math:`(N, *)` or :math:`(*)` where :math:`*` means any number of
            additional dimensions.
        target (Tensor): Tensor of shape :math:`(N, *)` or :math:`(*)`, same shape as `x`, or same shape
            as `x` but with one dimension equal to 1 (to allow broadcasting).
        var (Tensor): Tensor of shape :math:`(N, *)` or :math:`(*)`, same shape as `x`, or same shape as `x`
            but with one dimension equal to 1, or same shape as `x` but with one fewer dimension
            (to allow for broadcasting).
        full (bool): Whether to include the constant term in the loss calculation. When :math:`full=True`, the
            constant term `const.` is :math:`0.5 * \log(2\pi)`. Default: False.
        eps (float): Used to improve the stability of the log function; must be greater than 0. Default: 1e-6.
        reduction (str): Apply specific reduction method to the output: 'none', 'mean', or 'sum'. Default: 'mean'.

    Returns:
        Tensor or Tensor scalar, the computed loss depending on `reduction`.

    Raises:
        TypeError: If `x` is not a Tensor.
        TypeError: If `target` is not a Tensor.
        TypeError: If `var` is not a Tensor.
        TypeError: If `full` is not a bool.
        TypeError: If `eps` is not a float.
        ValueError: If `eps` is not a float within (0, inf).
        ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> import mindspore.ops as ops
        >>> import mindspore.common.dtype as mstype
        >>> arr1 = np.arange(8).reshape((4, 2))
        >>> arr2 = np.array([2, 3, 1, 4, 6, 4, 4, 9]).reshape((4, 2))
        >>> x = Tensor(arr1, mstype.float32)
        >>> var = Tensor(np.ones((4, 1)), mstype.float32)
        >>> target = Tensor(arr2, mstype.float32)
        >>> output = ops.gaussian_nll_loss(x, target, var)
        >>> print(output)
        1.4374993

    Reference:
        Nix, D. A. and Weigend, A. S., "Estimating the mean and variance of the
        target probability distribution", Proceedings of 1994 IEEE International
        Conference on Neural Networks (ICNN'94), Orlando, FL, USA, 1994, pp. 55-60
        vol.1, doi: 10.1109/ICNN.1994.374138.
    """
    if not isinstance(x, Tensor):
        raise TypeError(f"For 'gaussian_nll_loss', 'x' must be a tensor, but got {type(x)}.")
    if not isinstance(target, Tensor):
        raise TypeError(f"For 'gaussian_nll_loss', 'target' must be a tensor, but got {type(target)}.")
    if not isinstance(var, Tensor):
        raise TypeError(f"For 'gaussian_nll_loss', 'var' must be a tensor, but got {type(var)}.")
    _check_gaussian_nll_loss(full, eps, reduction)
    max_op = P.Maximum()
    log_op = P.Log()
    square_op = P.Square()
    maxima = max_op(var, eps)
    logarithm = log_op(maxima)
    squared_loss = square_op(x - target)
    c = 0 if not full else 0.5 * log(2 * pi)
    loss = 0.5 * (logarithm + squared_loss / maxima) + c
    if reduction == 'mean':
        loss = loss.mean()
    elif reduction == 'sum':
        loss = loss.sum()
    return loss


@constexpr
def _check_hinge_embedding_loss(shape, shape2, prim_name):
    if shape2 != shape:
@ -4718,6 +4812,7 @@ __all__ = [
'elu',
'gelu',
'hinge_embedding_loss',
'gaussian_nll_loss',
'lp_pool1d',
'lp_pool2d',
'max_unpool1d',
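For reference, a short usage sketch of the new functional API covering the three `reduction` modes; the expected values mirror the tests added further down in this commit:

import numpy as np
import mindspore.common.dtype as mstype
import mindspore.ops as ops
from mindspore import Tensor

x = Tensor(np.arange(8).reshape((4, 2)), mstype.float32)
target = Tensor(np.array([2, 3, 1, 4, 6, 4, 4, 9]).reshape((4, 2)), mstype.float32)
var = Tensor(np.ones((4, 1)), mstype.float32)  # final dimension of 1, broadcast against x

# 'none' keeps the element-wise (4, 2) loss; 'sum' and 'mean' reduce it to a scalar.
print(ops.gaussian_nll_loss(x, target, var, reduction='none'))
# [[2.  2. ] [0.5 0.5] [2.  0.5] [2.  2. ]]
print(ops.gaussian_nll_loss(x, target, var, reduction='sum'))   # 11.5
print(ops.gaussian_nll_loss(x, target, var, reduction='mean'))  # 1.4375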


@ -0,0 +1,81 @@
import numpy as np
import pytest
import mindspore.common.dtype as mstype
import mindspore.ops as ops
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
class Net(nn.Cell):
    def __init__(self, full=False, reduction='mean'):
        super(Net, self).__init__()
        self.full = full
        self.reduction = reduction

    def construct(self, x, label, v):
        loss = ops.gaussian_nll_loss(x, label, v, full=self.full, reduction=self.reduction)
        return loss


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
@pytest.mark.parametrize('full', [True, False])
def test_gaussian_nll_loss_full(mode, full):
    """
    Feature: gaussian_nll_loss with full=True and full=False
    Description: Verify the result of gaussian_nll_loss with the default reduction ('mean')
    Expectation: success
    """
    context.set_context(mode=mode)
    net = Net(full=full)
    arr1 = np.arange(8).reshape((4, 2))
    arr2 = np.array([2, 3, 1, 4, 6, 4, 4, 9]).reshape((4, 2))
    a = Tensor(arr1, mstype.float32)
    b = Tensor(arr2, mstype.float32)
    var = Tensor(np.ones((4, 1)), mstype.float32)
    output = net(a, b, var)
    if full:
        expected = np.array(2.35644, np.float32)
    else:
        expected = np.array(1.4375, np.float32)
    assert np.allclose(output.asnumpy(), expected)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
@pytest.mark.parametrize('reduction', ['mean', 'sum', 'none'])
def test_gaussian_nll_loss_reduction(mode, reduction):
    """
    Feature: gaussian_nll_loss with reduction='mean', 'sum' and 'none'
    Description: Verify the result of gaussian_nll_loss with full=False
    Expectation: success
    """
    context.set_context(mode=mode)
    net = Net(reduction=reduction)
    arr1 = np.arange(8).reshape((4, 2))
    arr2 = np.array([2, 3, 1, 4, 6, 4, 4, 9]).reshape((4, 2))
    a = Tensor(arr1, mstype.float32)
    b = Tensor(arr2, mstype.float32)
    var = Tensor(np.ones((4, 1)), mstype.float32)
    output = net(a, b, var)
    if reduction == 'mean':
        expected = np.array(1.4375, np.float32)
    elif reduction == 'sum':
        expected = np.array(11.5, np.float32)
    else:
        expected = np.array([[2.0000, 2.0000], [0.5000, 0.5000],
                             [2.0000, 0.5000], [2.0000, 2.0000]], np.float32)
    assert np.allclose(output.asnumpy(), expected)


@ -0,0 +1,23 @@
import numpy as np
import pytest
import mindspore.common.dtype as mstype
import mindspore.ops as ops
from mindspore import Tensor
from mindspore import context
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_gaussian_nll_loss_abnormal_full(mode):
    """
    Feature: gaussian_nll_loss
    Description: Verify abnormal inputs of gaussian_nll_loss
    Expectation: raise TypeError
    """
    context.set_context(mode=mode)
    arr1 = np.arange(8).reshape((4, 2))
    arr2 = np.array([2, 3, 1, 4, 6, 4, 4, 9]).reshape((4, 2))
    a = Tensor(arr1, mstype.float32)
    b = Tensor(arr2, mstype.float32)
    var = Tensor(np.ones((4, 1)), mstype.float32)
    with pytest.raises(TypeError):
        ops.gaussian_nll_loss(a, b, var, full=1, eps=1e-6, reduction='mean')