add function.gaussian_nll_loss

commit e63c57ef62 (parent 9b33f045b6)
@@ -62,6 +62,7 @@ mindspore.ops.function
     mindspore.ops.binary_cross_entropy
     mindspore.ops.binary_cross_entropy_with_logits
     mindspore.ops.cross_entropy
+    mindspore.ops.gaussian_nll_loss
     mindspore.ops.hinge_embedding_loss
     mindspore.ops.mse_loss
     mindspore.ops.nll_loss
@@ -0,0 +1,41 @@
+mindspore.ops.gaussian_nll_loss
+================================
+
+.. py:function:: mindspore.ops.gaussian_nll_loss(x, target, var, full=False, eps=1e-6, reduction='mean')
+
+    Gaussian negative log-likelihood loss.
+
+    The target values are treated as samples from a Gaussian distribution whose expectation and variance are predicted by the neural network. For a Tensor `x` modelled as a Gaussian distribution, a Tensor `target` recording the expectations, and a Tensor `var` of variances that are all positive, the computed loss is:
+
+    .. math::
+        \text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var},
+        \ \text{eps}\right)\right) + \frac{\left(\text{x} - \text{target}\right)^2}
+        {\text{max}\left(\text{var}, \ \text{eps}\right)}\right) + \text{const.}
+
+    where :math:`eps` is used for stability of :math:`log`. By default the constant term is omitted unless :math:`full=True`. If the shape of :math:`var` is not the same as that of `x` (due to a homoscedasticity assumption), its last dimension must be 1, or it must have one fewer dimension (with all other dimensions the same) for correct broadcasting.
+
+    Parameters:
+        - **x** (Tensor) - Tensor of shape :math:`(N, *)` or :math:`(*)`, where `*` means any number of additional dimensions.
+        - **target** (Tensor) - Tensor of shape :math:`(N, *)` or :math:`(*)`, same shape as `x`, or same shape but with one dimension equal to 1 (to allow broadcasting).
+        - **var** (Tensor) - Tensor of shape :math:`(N, *)` or :math:`(*)`, same shape as `x`, or same shape but with one dimension equal to 1, or with one fewer dimension (to allow broadcasting).
+        - **full** (bool, optional) - Whether to include the constant term of the loss. If True, the constant is :math:`const = 0.5*log(2*pi)`. Default: False.
+        - **eps** (float, optional) - Used to improve the stability of log; must be greater than 0. Default: 1e-6.
+        - **reduction** (str, optional) - The reduction applied to the output: 'none', 'mean' or 'sum'. Default: 'mean'.
+
+    Returns:
+        Tensor or Tensor scalar, the computed loss depending on :math:`reduction`.
+
+    Raises:
+        - **TypeError** - `x` is not a Tensor.
+        - **TypeError** - `target` is not a Tensor.
+        - **TypeError** - `var` is not a Tensor.
+        - **TypeError** - `full` is not a bool.
+        - **TypeError** - `eps` is not a float.
+        - **ValueError** - `eps` is not a float in (0, inf).
+        - **ValueError** - `reduction` is not one of "none", "mean" or "sum".
+
+    Reference:
+        Nix, D. A. and Weigend, A. S., "Estimating the mean and variance of the
+        target probability distribution", Proceedings of 1994 IEEE International
+        Conference on Neural Networks (ICNN'94), Orlando, FL, USA, 1994, pp. 55-60
+        vol.1, doi: 10.1109/ICNN.1994.374138.
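A minimal NumPy sketch of the formula above (a plain re-implementation for illustration, not the MindSpore kernel; the helper name gaussian_nll_loss_ref is hypothetical), using the same example inputs as the tests added in this commit:

    import numpy as np

    def gaussian_nll_loss_ref(x, target, var, full=False, eps=1e-6, reduction='mean'):
        # Clamp the variance at eps for log stability, exactly as in the formula.
        maxima = np.maximum(var, eps)
        loss = 0.5 * (np.log(maxima) + (x - target) ** 2 / maxima)
        if full:
            loss = loss + 0.5 * np.log(2 * np.pi)  # the otherwise-omitted constant
        if reduction == 'mean':
            return loss.mean()
        if reduction == 'sum':
            return loss.sum()
        return loss

    x = np.arange(8, dtype=np.float32).reshape(4, 2)
    target = np.array([2, 3, 1, 4, 6, 4, 4, 9], dtype=np.float32).reshape(4, 2)
    var = np.ones((4, 1), dtype=np.float32)      # broadcasts against (4, 2)
    print(gaussian_nll_loss_ref(x, target, var))  # 1.4375, matching the example output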
@@ -63,6 +63,7 @@ Loss Functions
     mindspore.ops.binary_cross_entropy
     mindspore.ops.binary_cross_entropy_with_logits
     mindspore.ops.cross_entropy
+    mindspore.ops.gaussian_nll_loss
     mindspore.ops.hinge_embedding_loss
     mindspore.ops.mse_loss
     mindspore.ops.nll_loss
@@ -15,7 +15,6 @@
 """loss"""
 from __future__ import absolute_import, division
 
-import math
 import mindspore
 import mindspore.common.dtype as mstype
 import mindspore.ops as ops
@@ -2390,7 +2389,7 @@ class GaussianNLLLoss(LossBase):
         >>> var = Tensor(np.ones((4, 1)), mstype.float32)
         >>> output = loss(logits, labels, var)
         >>> print(output)
-        Tensor(shape=[], dtype=Float32, value= 1.4375)
+        1.4374993
 
     Reference:
         Nix, D. A. and Weigend, A. S., "Estimating the mean and variance of the
@@ -2400,25 +2399,19 @@ class GaussianNLLLoss(LossBase):
     """
 
     def __init__(self, *, full=False, eps=1e-6, reduction='mean'):
-        super(GaussianNLLLoss, self).__init__(reduction)
+        super(GaussianNLLLoss, self).__init__()
         validator.check_float_range(eps, 0, float('inf'), Rel.INC_NEITHER, "eps", self.cls_name)
         validator.check_value_type('full', full, [bool], self.cls_name)
+        validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', 'gaussian_nll_loss')
         self.full = full
         self.eps = eps
-        self.max = P.Maximum()
-        self.log = P.Log()
-        self.square = P.Square()
+        self.reduction = reduction
 
     def construct(self, logits, labels, var):
         _check_is_tensor('logits', logits, self.cls_name)
         _check_is_tensor('labels', labels, self.cls_name)
         _check_is_tensor('var', var, self.cls_name)
-        maxima = self.max(var, self.eps)
-        logarithm = self.log(maxima)
-        squared_loss = self.square(logits - labels)
-        c = 0 if not self.full else 0.5 * math.log(2 * math.pi)
-        loss = 0.5 * (logarithm + squared_loss / maxima) + c
-        return self.get_loss(loss)
+        return ops.gaussian_nll_loss(logits, labels, var, self.full, self.eps, self.reduction)
 
 
 class HingeEmbeddingLoss(LossBase):
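With this change the class is a thin wrapper over the functional API. A short usage sketch (assuming standard MindSpore imports; not part of the commit) showing the two entry points agree:

    import numpy as np
    import mindspore.nn as nn
    import mindspore.ops as ops
    import mindspore.common.dtype as mstype
    from mindspore import Tensor

    # Same inputs as the docstring example above.
    logits = Tensor(np.arange(8).reshape((4, 2)), mstype.float32)
    labels = Tensor(np.array([2, 3, 1, 4, 6, 4, 4, 9]).reshape((4, 2)), mstype.float32)
    var = Tensor(np.ones((4, 1)), mstype.float32)

    # The class wrapper now delegates, so both forms return the same value.
    loss_fn = nn.GaussianNLLLoss(full=False, eps=1e-6, reduction='mean')
    print(loss_fn(logits, labels, var))                # ~1.4375
    print(ops.gaussian_nll_loss(logits, labels, var))  # ~1.4375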
@@ -2480,7 +2473,7 @@ class HingeEmbeddingLoss(LossBase):
         >>> loss = nn.HingeEmbeddingLoss(reduction='mean')
         >>> output = loss(logits, labels)
         >>> print(output)
-        Tensor(shape=[], dtype=Float32, value= 1.6666667)
+        0.16666667
     """
 
     def __init__(self, margin=1.0, reduction='mean'):
@@ -372,6 +372,7 @@ from .nn_func import (
     elu,
     gelu,
     hinge_embedding_loss,
+    gaussian_nll_loss,
     lp_pool1d,
     lp_pool2d,
     mse_loss,
@@ -15,7 +15,7 @@
 
 """Defines nn operators with functional form."""
 from __future__ import absolute_import
-from math import pi
+from math import pi, log
 
 import mindspore.ops as ops
 from mindspore.ops.primitive import constexpr
@@ -3245,6 +3245,100 @@ def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reducti
     return (loss, log_alpha)
 
 
+@constexpr
+def _check_gaussian_nll_loss(full, eps, reduction):
+    validator.check_value_type('full', full, [bool], 'gaussian_nll_loss')
+    validator.check_positive_float(eps, 'eps', 'gaussian_nll_loss')
+    validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', 'gaussian_nll_loss')
+
+
+def gaussian_nll_loss(x, target, var, full=False, eps=1e-6, reduction='mean'):
+    r"""
+    Gaussian negative log likelihood loss.
+
+    The targets are treated as samples from Gaussian distributions with expectations and variances predicted by the
+    neural network. For a `target` tensor modelled as having a Gaussian distribution with a tensor of expectations
+    `x` and a tensor of positive variances `var`, the loss is:
+
+    .. math::
+        \text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var},
+        \ \text{eps}\right)\right) + \frac{\left(\text{x} - \text{target}\right)^2}
+        {\text{max}\left(\text{var}, \ \text{eps}\right)}\right) + \text{const.}
+
+    where `eps` is used for stability of :math:`log`. By default, the constant term of the loss function is omitted
+    unless :math:`full=True`. If the shape of :math:`var` is not the same as that of `x` (due to a
+    homoscedastic assumption), it must either have a final dimension of 1 or have one fewer dimension
+    (with all other sizes being the same) for correct broadcasting.
+
+    Args:
+        x (Tensor): Tensor of shape :math:`(N, *)` or :math:`(*)` where :math:`*` means any number of
+            additional dimensions.
+        target (Tensor): Tensor of shape :math:`(N, *)` or :math:`(*)`, same shape as x, or same shape
+            as x but with one dimension equal to 1 (to allow broadcasting).
+        var (Tensor): Tensor of shape :math:`(N, *)` or :math:`(*)`, same shape as x, or same shape as x
+            but with one dimension equal to 1, or same shape as x but with one fewer dimension
+            (to allow for broadcasting).
+        full (bool): Whether to include the constant term in the loss calculation. When :math:`full=True`,
+            the constant term `const.` will be :math:`0.5 * log(2\pi)`. Default: False.
+        eps (float): Used to improve the stability of the log function; must be greater than 0. Default: 1e-6.
+        reduction (str): Apply specific reduction method to the output: 'none', 'mean', or 'sum'. Default: 'mean'.
+
+    Returns:
+        Tensor or Tensor scalar, the computed loss depending on `reduction`.
+
+    Raises:
+        TypeError: If `x` is not a Tensor.
+        TypeError: If `target` is not a Tensor.
+        TypeError: If `var` is not a Tensor.
+        TypeError: If `full` is not a bool.
+        TypeError: If `eps` is not a float.
+        ValueError: If `eps` is not a float within (0, inf).
+        ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>> import mindspore.ops as ops
+        >>> import mindspore.common.dtype as mstype
+        >>> arr1 = np.arange(8).reshape((4, 2))
+        >>> arr2 = np.array([2, 3, 1, 4, 6, 4, 4, 9]).reshape((4, 2))
+        >>> x = Tensor(arr1, mstype.float32)
+        >>> var = Tensor(np.ones((4, 1)), mstype.float32)
+        >>> target = Tensor(arr2, mstype.float32)
+        >>> output = ops.gaussian_nll_loss(x, target, var)
+        >>> print(output)
+        1.4374993
+
+    Reference:
+        Nix, D. A. and Weigend, A. S., "Estimating the mean and variance of the
+        target probability distribution", Proceedings of 1994 IEEE International
+        Conference on Neural Networks (ICNN'94), Orlando, FL, USA, 1994, pp. 55-60
+        vol.1, doi: 10.1109/ICNN.1994.374138.
+    """
+    if not isinstance(x, Tensor):
+        raise TypeError(f"For 'gaussian_nll_loss', 'x' must be a tensor, but got {type(x)}.")
+    if not isinstance(target, Tensor):
+        raise TypeError(f"For 'gaussian_nll_loss', 'target' must be a tensor, but got {type(target)}.")
+    if not isinstance(var, Tensor):
+        raise TypeError(f"For 'gaussian_nll_loss', 'var' must be a tensor, but got {type(var)}.")
+    _check_gaussian_nll_loss(full, eps, reduction)
+    max_op = P.Maximum()
+    log_op = P.Log()
+    square_op = P.Square()
+    maxima = max_op(var, eps)
+    logarithm = log_op(maxima)
+    squared_loss = square_op(x - target)
+    c = 0 if not full else 0.5 * log(2 * pi)
+    loss = 0.5 * (logarithm + squared_loss / maxima) + c
+    if reduction == 'mean':
+        loss = loss.mean()
+    elif reduction == 'sum':
+        loss = loss.sum()
+    return loss
+
+
 @constexpr
 def _check_hinge_embedding_loss(shape, shape2, prim_name):
     if shape2 != shape:
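The trailing-dimension-of-1 broadcasting case described in the docstring can be exercised with a quick sketch (shapes and values here are illustrative, not from the commit):

    import numpy as np
    import mindspore.ops as ops
    import mindspore.common.dtype as mstype
    from mindspore import Tensor

    x = Tensor(np.arange(8).reshape((4, 2)), mstype.float32)
    target = Tensor(np.zeros((4, 2)), mstype.float32)

    # Heteroscedastic: one variance per element, same shape as x.
    var_full = Tensor(np.full((4, 2), 0.5), mstype.float32)
    # Homoscedastic per row: final dimension of 1, broadcast across columns.
    var_row = Tensor(np.full((4, 1), 0.5), mstype.float32)

    print(ops.gaussian_nll_loss(x, target, var_full))  # scalar under the default 'mean'
    print(ops.gaussian_nll_loss(x, target, var_row))   # same value, via broadcasting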
@@ -4718,6 +4812,7 @@ __all__ = [
     'elu',
     'gelu',
     'hinge_embedding_loss',
+    'gaussian_nll_loss',
     'lp_pool1d',
     'lp_pool2d',
     'max_unpool1d',
@@ -0,0 +1,81 @@
+import numpy as np
+import pytest
+import mindspore.common.dtype as mstype
+import mindspore.ops as ops
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+
+
+class Net(nn.Cell):
+    def __init__(self, full=False, reduction='mean'):
+        super(Net, self).__init__()
+        self.full = full
+        self.reduction = reduction
+
+    def construct(self, x, label, v):
+        loss = ops.gaussian_nll_loss(x, label, v, full=self.full, reduction=self.reduction)
+        return loss
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_arm_cpu
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
+@pytest.mark.parametrize('full', [True, False])
+def test_gaussian_nll_loss_full(mode, full):
+    """
+    Feature: gaussian_nll_loss with reduction='mean'
+    Description: Verify the result of gaussian_nll_loss
+    Expectation: success
+    """
+    context.set_context(mode=mode)
+    net = Net(full=full)
+    arr1 = np.arange(8).reshape((4, 2))
+    arr2 = np.array([2, 3, 1, 4, 6, 4, 4, 9]).reshape((4, 2))
+    a = Tensor(arr1, mstype.float32)
+    b = Tensor(arr2, mstype.float32)
+    var = Tensor(np.ones((4, 1)), mstype.float32)
+    output = net(a, b, var)
+    if full:
+        expected = np.array(2.35644, np.float32)
+    else:
+        expected = np.array(1.4375, np.float32)
+    assert np.allclose(output.asnumpy(), expected)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_arm_cpu
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
+@pytest.mark.parametrize('reduction', ['mean', 'sum', 'none'])
+def test_gaussian_nll_loss_reduction(mode, reduction):
+    """
+    Feature: gaussian_nll_loss with full=False
+    Description: Verify the result of gaussian_nll_loss
+    Expectation: success
+    """
+    context.set_context(mode=mode)
+    net = Net(reduction=reduction)
+    arr1 = np.arange(8).reshape((4, 2))
+    arr2 = np.array([2, 3, 1, 4, 6, 4, 4, 9]).reshape((4, 2))
+    a = Tensor(arr1, mstype.float32)
+    b = Tensor(arr2, mstype.float32)
+    var = Tensor(np.ones((4, 1)), mstype.float32)
+    output = net(a, b, var)
+    if reduction == 'mean':
+        expected = np.array(1.4375, np.float32)
+    elif reduction == 'sum':
+        expected = np.array(11.5, np.float32)
+    else:
+        expected = np.array([[2.0000, 2.0000], [0.5000, 0.5000],
+                             [2.0000, 0.5000], [2.0000, 2.0000]], np.float32)
+    assert np.allclose(output.asnumpy(), expected)
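The expected constants in these tests follow directly from the formula; a quick hand check in plain NumPy (illustrative only):

    import numpy as np

    x = np.arange(8, dtype=np.float32).reshape(4, 2)
    target = np.array([2, 3, 1, 4, 6, 4, 4, 9], dtype=np.float32).reshape(4, 2)
    var = np.ones((4, 1), dtype=np.float32)

    # With var == 1, log(max(var, eps)) == 0, so each element is 0.5 * (x - target)**2.
    elementwise = 0.5 * (x - target) ** 2 / var
    print(elementwise)         # [[2. 2.] [0.5 0.5] [2. 0.5] [2. 2.]] -- the 'none' case
    print(elementwise.sum())   # 11.5   -- the 'sum' case
    print(elementwise.mean())  # 1.4375 -- the 'mean' case
    print(elementwise.mean() + 0.5 * np.log(2 * np.pi))  # ~2.35644 -- full=True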
@@ -0,0 +1,23 @@
+import numpy as np
+import pytest
+import mindspore.common.dtype as mstype
+import mindspore.ops as ops
+from mindspore import Tensor
+from mindspore import context
+
+
+@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
+def test_gaussian_nll_loss_abnormal_full(mode):
+    """
+    Feature: gaussian_nll_loss
+    Description: Verify abnormal inputs of gaussian_nll_loss
+    Expectation: raise TypeError
+    """
+    context.set_context(mode=mode)
+    arr1 = np.arange(8).reshape((4, 2))
+    arr2 = np.array([2, 3, 1, 4, 6, 4, 4, 9]).reshape((4, 2))
+    a = Tensor(arr1, mstype.float32)
+    b = Tensor(arr2, mstype.float32)
+    var = Tensor(np.ones((4, 1)), mstype.float32)
+    with pytest.raises(TypeError):
+        ops.gaussian_nll_loss(a, b, var, full=1, eps=1e-6, reduction='mean')