From e63c57ef62203f4c5b3275ca35c087fe303fe9b4 Mon Sep 17 00:00:00 2001
From: shaojunsong
Date: Tue, 8 Nov 2022 09:36:30 +0800
Subject: [PATCH] add function.gaussian_nll_loss

---
 .../api/api_python/mindspore.ops.function.rst |  1 +
 .../mindspore.ops.func_gaussian_nll_loss.rst  | 41 ++++++++
 .../api_python_en/mindspore.ops.function.rst  |  1 +
 mindspore/python/mindspore/nn/loss/loss.py    | 19 ++--
 .../python/mindspore/ops/function/__init__.py |  1 +
 .../python/mindspore/ops/function/nn_func.py  | 97 ++++++++++++++++++-
 tests/st/ops/test_gaussian_nll_loss.py        | 81 ++++++++++++++++
 tests/ut/python/ops/test_gaussian_nll_loss.py | 23 +++++
 8 files changed, 250 insertions(+), 14 deletions(-)
 create mode 100644 docs/api/api_python/ops/mindspore.ops.func_gaussian_nll_loss.rst
 create mode 100644 tests/st/ops/test_gaussian_nll_loss.py
 create mode 100644 tests/ut/python/ops/test_gaussian_nll_loss.py

diff --git a/docs/api/api_python/mindspore.ops.function.rst b/docs/api/api_python/mindspore.ops.function.rst
index d7f8b937ad7..e281645e9d0 100644
--- a/docs/api/api_python/mindspore.ops.function.rst
+++ b/docs/api/api_python/mindspore.ops.function.rst
@@ -62,6 +62,7 @@ mindspore.ops.function
     mindspore.ops.binary_cross_entropy
     mindspore.ops.binary_cross_entropy_with_logits
     mindspore.ops.cross_entropy
+    mindspore.ops.gaussian_nll_loss
     mindspore.ops.hinge_embedding_loss
     mindspore.ops.mse_loss
     mindspore.ops.nll_loss
diff --git a/docs/api/api_python/ops/mindspore.ops.func_gaussian_nll_loss.rst b/docs/api/api_python/ops/mindspore.ops.func_gaussian_nll_loss.rst
new file mode 100644
index 00000000000..e24bdc6e7d2
--- /dev/null
+++ b/docs/api/api_python/ops/mindspore.ops.func_gaussian_nll_loss.rst
@@ -0,0 +1,41 @@
+mindspore.ops.gaussian_nll_loss
+================================
+
+.. py:function:: mindspore.ops.gaussian_nll_loss(x, target, var, full=False, eps=1e-6, reduction='mean')
+
+    Gaussian negative log likelihood loss.
+
+    The targets are treated as samples from a Gaussian distribution whose expectation and variance are
+    predicted by the neural network. For a Tensor `x` of expectations modelled by the Gaussian distribution,
+    a Tensor `target` of sampled values, and a Tensor `var` of positive variances, the loss is:
+
+    .. math::
+        \text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var},
+        \ \text{eps}\right)\right) + \frac{\left(\text{x} - \text{target}\right)^2}
+        {\text{max}\left(\text{var}, \ \text{eps}\right)}\right) + \text{const.}
+
+    where :math:`eps` is used for stability of :math:`\log`. By default the constant term is omitted unless
+    :math:`full=True`. If the shape of `var` is not the same as that of `x` (due to a homoscedastic
+    assumption), it must either have a final dimension of 1 or have one fewer dimension (with all other
+    sizes being the same) for correct broadcasting.
+
+    Parameters:
+        - **x** (Tensor) - Shape :math:`(N, *)` or :math:`(*)`, where :math:`*` means any number of additional dimensions.
+        - **target** (Tensor) - Shape :math:`(N, *)` or :math:`(*)`, same shape as `x`, or same shape as `x` but with one dimension equal to 1 (to allow broadcasting).
+        - **var** (Tensor) - Shape :math:`(N, *)` or :math:`(*)`, same shape as `x`, or same shape as `x` but with one dimension equal to 1, or with one fewer dimension (to allow broadcasting).
+        - **full** (bool, optional) - Whether to include the constant term of the loss. If True, the constant is :math:`const = 0.5 * \log(2\pi)`. Default: False.
+        - **eps** (float, optional) - Used to improve the stability of the log function; must be greater than 0. Default: 1e-6.
+        - **reduction** (str, optional) - Reduction applied to the output: 'none', 'mean' or 'sum'. Default: 'mean'.
+
+    Returns:
+        Tensor or Tensor scalar, the computed loss depending on `reduction`.
+
+    Raises:
+        - **TypeError** - `x` is not a Tensor.
+        - **TypeError** - `target` is not a Tensor.
+        - **TypeError** - `var` is not a Tensor.
+        - **TypeError** - `full` is not a bool.
+        - **TypeError** - `eps` is not a float.
+        - **ValueError** - `eps` is not a float in (0, inf).
+        - **ValueError** - `reduction` is not one of 'none', 'mean' or 'sum'.
+
+    Reference:
+        Nix, D. A. and Weigend, A. S., "Estimating the mean and variance of the
+        target probability distribution", Proceedings of 1994 IEEE International
+        Conference on Neural Networks (ICNN'94), Orlando, FL, USA, 1994, pp. 55-60
+        vol.1, doi: 10.1109/ICNN.1994.374138.
diff --git a/docs/api/api_python_en/mindspore.ops.function.rst b/docs/api/api_python_en/mindspore.ops.function.rst
index 2beaff3ed40..1291cbac59d 100644
--- a/docs/api/api_python_en/mindspore.ops.function.rst
+++ b/docs/api/api_python_en/mindspore.ops.function.rst
@@ -63,6 +63,7 @@ Loss Functions
     mindspore.ops.binary_cross_entropy
     mindspore.ops.binary_cross_entropy_with_logits
     mindspore.ops.cross_entropy
+    mindspore.ops.gaussian_nll_loss
    mindspore.ops.hinge_embedding_loss
     mindspore.ops.mse_loss
     mindspore.ops.nll_loss
diff --git a/mindspore/python/mindspore/nn/loss/loss.py b/mindspore/python/mindspore/nn/loss/loss.py
index 3bb3f93f433..783124e89b5 100644
--- a/mindspore/python/mindspore/nn/loss/loss.py
+++ b/mindspore/python/mindspore/nn/loss/loss.py
@@ -15,7 +15,6 @@
 """loss"""
 from __future__ import absolute_import, division
 
-import math
 import mindspore
 import mindspore.common.dtype as mstype
 import mindspore.ops as ops
@@ -2390,7 +2389,7 @@ class GaussianNLLLoss(LossBase):
         >>> var = Tensor(np.ones((4, 1)), mstype.float32)
         >>> output = loss(logits, labels, var)
         >>> print(output)
-        Tensor(shape=[], dtype=Float32, value= 1.4375)
+        1.4374993
 
     Reference:
         Nix, D. A. and Weigend, A. S., "Estimating the mean and variance of the
@@ -2400,25 +2399,19 @@ class GaussianNLLLoss(LossBase):
     """
 
     def __init__(self, *, full=False, eps=1e-6, reduction='mean'):
-        super(GaussianNLLLoss, self).__init__(reduction)
+        super(GaussianNLLLoss, self).__init__()
         validator.check_float_range(eps, 0, float('inf'), Rel.INC_NEITHER, "eps", self.cls_name)
         validator.check_value_type('full', full, [bool], self.cls_name)
+        validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', 'gaussian_nll_loss')
         self.full = full
         self.eps = eps
-        self.max = P.Maximum()
-        self.log = P.Log()
-        self.square = P.Square()
+        self.reduction = reduction
 
     def construct(self, logits, labels, var):
         _check_is_tensor('logits', logits, self.cls_name)
         _check_is_tensor('labels', labels, self.cls_name)
         _check_is_tensor('var', var, self.cls_name)
-        maxima = self.max(var, self.eps)
-        logarithm = self.log(maxima)
-        squared_loss = self.square(logits - labels)
-        c = 0 if not self.full else 0.5 * math.log(2 * math.pi)
-        loss = 0.5 * (logarithm + squared_loss / maxima) + c
-        return self.get_loss(loss)
+        return ops.gaussian_nll_loss(logits, labels, var, self.full, self.eps, self.reduction)
 
 
 class HingeEmbeddingLoss(LossBase):
@@ -2480,7 +2473,7 @@ class HingeEmbeddingLoss(LossBase):
         >>> loss = nn.HingeEmbeddingLoss(reduction='mean')
         >>> output = loss(logits, labels)
         >>> print(output)
-        Tensor(shape=[], dtype=Float32, value= 1.6666667)
+        0.16666667
     """
 
     def __init__(self, margin=1.0, reduction='mean'):
diff --git a/mindspore/python/mindspore/ops/function/__init__.py b/mindspore/python/mindspore/ops/function/__init__.py
index 943430c9983..f626fa24f18 100644
--- a/mindspore/python/mindspore/ops/function/__init__.py
+++ b/mindspore/python/mindspore/ops/function/__init__.py
@@ -372,6 +372,7 @@ from .nn_func import (
     elu,
     gelu,
     hinge_embedding_loss,
+    gaussian_nll_loss,
     lp_pool1d,
     lp_pool2d,
     mse_loss,
diff --git a/mindspore/python/mindspore/ops/function/nn_func.py b/mindspore/python/mindspore/ops/function/nn_func.py
index 78e5ad91450..f58576952f6 100644
--- a/mindspore/python/mindspore/ops/function/nn_func.py
+++ b/mindspore/python/mindspore/ops/function/nn_func.py
@@ -15,7 +15,7 @@
 """Defines nn operators with functional form."""
 from __future__ import absolute_import
 
-from math import pi
+from math import pi, log
 
 import mindspore.ops as ops
 from mindspore.ops.primitive import constexpr
@@ -3245,6 +3245,100 @@ def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reducti
     return (loss, log_alpha)
 
 
+@constexpr
+def _check_gaussian_nll_loss(full, eps, reduction):
+    validator.check_value_type('full', full, [bool], 'gaussian_nll_loss')
+    validator.check_positive_float(eps, 'eps', 'gaussian_nll_loss')
+    validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', 'gaussian_nll_loss')
+
+
+def gaussian_nll_loss(x, target, var, full=False, eps=1e-6, reduction='mean'):
+    r"""
+    Gaussian negative log likelihood loss.
+
+    The targets are treated as samples from Gaussian distributions with expectations and variances predicted by
+    the neural network. For a `target` tensor modelled as having a Gaussian distribution with a tensor of
+    expectations `x` and a tensor of positive variances `var`, the loss is:
+
+    .. math::
+        \text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var},
+        \ \text{eps}\right)\right) + \frac{\left(\text{x} - \text{target}\right)^2}
+        {\text{max}\left(\text{var}, \ \text{eps}\right)}\right) + \text{const.}
+
+    where `eps` is used for stability of :math:`log`. By default, the constant term of the loss function is omitted
+    unless :math:`full=True`. If the shape of `var` is not the same as that of `x` (due to a
+    homoscedastic assumption), it must either have a final dimension of 1 or have one fewer dimension
+    (with all other sizes being the same) for correct broadcasting.
+
+    Args:
+        x (Tensor): Tensor of shape :math:`(N, *)` or :math:`(*)` where :math:`*` means any number of
+            additional dimensions.
+        target (Tensor): Tensor of shape :math:`(N, *)` or :math:`(*)`, same shape as `x`, or same shape
+            as `x` but with one dimension equal to 1 (to allow broadcasting).
+        var (Tensor): Tensor of shape :math:`(N, *)` or :math:`(*)`, same shape as `x`, or same shape as `x`
+            but with one dimension equal to 1, or same shape as `x` but with one fewer dimension
+            (to allow for broadcasting).
+        full (bool): Include the constant term in the loss calculation. When :math:`full=True`, the constant
+            term `const.` will be :math:`0.5 * log(2\pi)`. Default: False.
+        eps (float): Used to improve the stability of the log function; must be greater than 0. Default: 1e-6.
+        reduction (str): Apply specific reduction method to the output: 'none', 'mean', or 'sum'. Default: 'mean'.
+
+    Returns:
+        Tensor or Tensor scalar, the computed loss depending on `reduction`.
+
+    Raises:
+        TypeError: If `x` is not a Tensor.
+        TypeError: If `target` is not a Tensor.
+        TypeError: If `var` is not a Tensor.
+        TypeError: If `full` is not a bool.
+        TypeError: If `eps` is not a float.
+        ValueError: If `eps` is not a float within (0, inf).
+        ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>> import mindspore.ops as ops
+        >>> import mindspore.common.dtype as mstype
+        >>> arr1 = np.arange(8).reshape((4, 2))
+        >>> arr2 = np.array([2, 3, 1, 4, 6, 4, 4, 9]).reshape((4, 2))
+        >>> x = Tensor(arr1, mstype.float32)
+        >>> var = Tensor(np.ones((4, 1)), mstype.float32)
+        >>> target = Tensor(arr2, mstype.float32)
+        >>> output = ops.gaussian_nll_loss(x, target, var)
+        >>> print(output)
+        1.4374993
+
+    Reference:
S., "Estimating the mean and variance of the + target probability distribution", Proceedings of 1994 IEEE International + Conference on Neural Networks (ICNN'94), Orlando, FL, USA, 1994, pp. 55-60 + vol.1, doi: 10.1109/ICNN.1994.374138. + """ + if not isinstance(x, Tensor): + raise TypeError(f"For 'gaussian_nll_loss', 'x' must be a tensor, but got {type(x)}.") + if not isinstance(target, Tensor): + raise TypeError(f"For 'gaussian_nll_loss', 'target' must be a tensor, but got {type(target)}.") + if not isinstance(var, Tensor): + raise TypeError(f"For 'gaussian_nll_loss', 'var' must be a tensor, but got {type(var)}.") + _check_gaussian_nll_loss(full, eps, reduction) + max_op = P.Maximum() + log_op = P.Log() + square_op = P.Square() + maxima = max_op(var, eps) + logarithm = log_op(maxima) + squared_loss = square_op(x - target) + c = 0 if not full else 0.5 * log(2 * pi) + loss = 0.5 * (logarithm + squared_loss / maxima) + c + if reduction == 'mean': + loss = loss.mean() + elif reduction == 'sum': + loss = loss.sum() + return loss + + @constexpr def _check_hinge_embedding_loss(shape, shape2, prim_name): if shape2 != shape: @@ -4718,6 +4812,7 @@ __all__ = [ 'elu', 'gelu', 'hinge_embedding_loss', + 'gaussian_nll_loss', 'lp_pool1d', 'lp_pool2d', 'max_unpool1d', diff --git a/tests/st/ops/test_gaussian_nll_loss.py b/tests/st/ops/test_gaussian_nll_loss.py new file mode 100644 index 00000000000..64dbcf48623 --- /dev/null +++ b/tests/st/ops/test_gaussian_nll_loss.py @@ -0,0 +1,81 @@ +import numpy as np +import pytest +import mindspore.common.dtype as mstype +import mindspore.ops as ops +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context + + +class Net(nn.Cell): + def __init__(self, full=False, reduction='mean'): + super(Net, self).__init__() + self.full = full + self.reduction = reduction + + def construct(self, x, label, v): + loss = ops.gaussian_nll_loss(x, label, v, full=self.full, reduction=self.reduction) + return loss + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('full', [True, False]) +def test_gaussian_nll_loss_full(mode, full): + """ + Feature: GaussianNLLLoss with reduction='mean' + Description: Verify the result of GaussianNLLLoss + Expectation: success + """ + context.set_context(mode=mode) + net = Net(full=full) + arr1 = np.arange(8).reshape((4, 2)) + arr2 = np.array([2, 3, 1, 4, 6, 4, 4, 9]).reshape((4, 2)) + a = Tensor(arr1, mstype.float32) + b = Tensor(arr2, mstype.float32) + var = Tensor(np.ones((4, 1)), mstype.float32) + output = net(a, b, var) + if full: + expected = np.array(2.35644, np.float32) + else: + expected = np.array(1.4375, np.float32) + assert np.allclose(output.asnumpy(), expected) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('reduction', ['mean', 'sum', 'none']) +def test_gaussian_nll_loss_reduction(mode, reduction): + """ + Feature: GaussianNLLLoss with full=False + Description: Verify the result of GaussianNLLLoss + Expectation: success + """ + 
+    context.set_context(mode=mode)
+    net = Net(reduction=reduction)
+    arr1 = np.arange(8).reshape((4, 2))
+    arr2 = np.array([2, 3, 1, 4, 6, 4, 4, 9]).reshape((4, 2))
+    a = Tensor(arr1, mstype.float32)
+    b = Tensor(arr2, mstype.float32)
+    var = Tensor(np.ones((4, 1)), mstype.float32)
+    output = net(a, b, var)
+    if reduction == 'mean':
+        expected = np.array(1.4375, np.float32)
+    elif reduction == 'sum':
+        expected = np.array(11.5, np.float32)
+    else:
+        expected = np.array([[2.0000, 2.0000], [0.5000, 0.5000],
+                             [2.0000, 0.5000], [2.0000, 2.0000]], np.float32)
+    assert np.allclose(output.asnumpy(), expected)
diff --git a/tests/ut/python/ops/test_gaussian_nll_loss.py b/tests/ut/python/ops/test_gaussian_nll_loss.py
new file mode 100644
index 00000000000..f930dea04d0
--- /dev/null
+++ b/tests/ut/python/ops/test_gaussian_nll_loss.py
@@ -0,0 +1,23 @@
+import numpy as np
+import pytest
+import mindspore.common.dtype as mstype
+import mindspore.ops as ops
+from mindspore import Tensor
+from mindspore import context
+
+
+@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
+def test_gaussian_nll_loss_abnormal_full(mode):
+    """
+    Feature: gaussian_nll_loss
+    Description: Verify abnormal inputs of gaussian_nll_loss
+    Expectation: raise TypeError
+    """
+    context.set_context(mode=mode)
+    arr1 = np.arange(8).reshape((4, 2))
+    arr2 = np.array([2, 3, 1, 4, 6, 4, 4, 9]).reshape((4, 2))
+    a = Tensor(arr1, mstype.float32)
+    b = Tensor(arr2, mstype.float32)
+    var = Tensor(np.ones((4, 1)), mstype.float32)
+    with pytest.raises(TypeError):
+        ops.gaussian_nll_loss(a, b, var, full=1, eps=1e-6, reduction='mean')
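
Below is a minimal usage sketch of the API this patch adds; it is an editor's illustration, not part of the diff. The input values mirror the docstring example above, the `full=True` and `reduction='none'` calls simply exercise the documented parameters, and the expected mean-reduced value (~1.4375) comes from the tests in this patch.

    import numpy as np
    import mindspore.common.dtype as mstype
    import mindspore.nn as nn
    import mindspore.ops as ops
    from mindspore import Tensor

    # Expectations predicted by the network, observed targets, and per-sample variances.
    x = Tensor(np.arange(8).reshape((4, 2)), mstype.float32)
    target = Tensor(np.array([2, 3, 1, 4, 6, 4, 4, 9]).reshape((4, 2)), mstype.float32)
    var = Tensor(np.ones((4, 1)), mstype.float32)

    # Functional form: constant term omitted, mean reduction (the defaults).
    print(ops.gaussian_nll_loss(x, target, var))  # ~1.4375

    # Element-wise loss with the 0.5 * log(2 * pi) constant included.
    print(ops.gaussian_nll_loss(x, target, var, full=True, reduction='none'))

    # Equivalent nn interface, which this patch refactors to call the functional op.
    loss_fn = nn.GaussianNLLLoss(full=False, eps=1e-6, reduction='mean')
    print(loss_fn(x, target, var))  # ~1.4375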