add ops.multilabel_soft_margin_loss

fandawei 2022-12-28 15:10:05 +08:00
parent bb7bd1b878
commit 7f1e73a354
11 changed files with 274 additions and 24 deletions

View File

@@ -249,6 +249,7 @@ Dropout Layer
mindspore.nn.MarginRankingLoss
mindspore.nn.MSELoss
mindspore.nn.MultiClassDiceLoss
mindspore.nn.MultiLabelSoftMarginLoss
mindspore.nn.NLLLoss
mindspore.nn.RMSELoss
mindspore.nn.SampledSoftmaxLoss

View File

@@ -45,6 +45,7 @@ mindspore.ops
mindspore.ops.max_unpool3d
mindspore.ops.multi_margin_loss
mindspore.ops.multi_label_margin_loss
mindspore.ops.multilabel_soft_margin_loss
mindspore.ops.kl_div
mindspore.ops.pad
mindspore.ops.padding

View File

@@ -0,0 +1,28 @@
mindspore.nn.MultiLabelSoftMarginLoss
======================================
.. py:class:: mindspore.nn.MultiLabelSoftMarginLoss(weight=None, reduction='mean')
Computes the loss for multi-label optimization based on maximum entropy. The loss is calculated as follows.
.. math::
\mathcal{L}_{D} = - \frac{1}{|D|}\sum_{i = 0}^{|D|}\left(
y_{i}\ln\frac{1}{1 + e^{- x_{i}}} + \left( 1 - y_{i}
\right)\ln\frac{1}{1 + e^{x_{i}}} \right)
where :math:`\mathcal{L}_{D}` is the loss, :math:`y_{i}` is the `target`,
and :math:`x_{i}` is the `x`. If `weight` is not None, it is multiplied with the loss of each class.
Args:
- **weight** (Union[Tensor, int, float]) - The manual rescaling weight given to each class. Default: None.
- **reduction** (str) - Specifies the reduction to apply to the output. Must be one of "mean", "sum", or "none". Default: "mean".
Inputs:
- **x** (Tensor) - A tensor of shape (N, C), where N is the batch size and C is the number of classes.
- **target** (Tensor) - The target value, with the same data type and shape as `x`.
Outputs:
Tensor, with the same data type as `x`. If `reduction` is "none", its shape is (N); otherwise, the output is a scalar.
Raises:
- **ValueError** - If the rank of `x` or `target` is not 2.
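A minimal usage sketch of the class interface (the inputs and the expected value mirror the tests added later in this commit):

>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> x = Tensor([[0.3, 0.6, 0.6], [0.9, 0.4, 0.2]])
>>> target = Tensor([[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]])
>>> loss = nn.MultiLabelSoftMarginLoss(reduction='mean')
>>> print(loss(x, target).asnumpy())
0.84693956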

View File

@@ -0,0 +1,26 @@
mindspore.ops.multilabel_soft_margin_loss
=========================================
.. py:function:: mindspore.ops.multilabel_soft_margin_loss(x, target, weight=None, reduction='mean')
Computes the loss for multi-label optimization based on maximum entropy. The loss is calculated as follows.
.. math::
\mathcal{L}_{D} = - \frac{1}{|D|}\sum_{i = 0}^{|D|}\left(
y_{i}\ln\frac{1}{1 + e^{- x_{i}}} + \left( 1 - y_{i}
\right)\ln\frac{1}{1 + e^{x_{i}}} \right)
where :math:`\mathcal{L}_{D}` is the loss, :math:`y_{i}` is the `target`,
and :math:`x_{i}` is the `x`. If `weight` is not None, it is multiplied with the loss of each class.
Args:
- **x** (Tensor) - A tensor of shape (N, C), where N is the batch size and C is the number of classes.
- **target** (Tensor) - The target value, with the same data type and shape as `x`.
- **weight** (Union[Tensor, int, float]) - The manual rescaling weight given to each class. Default: None.
- **reduction** (str) - Specifies the reduction to apply to the output. Must be one of "mean", "sum", or "none". Default: "mean".
Returns:
Tensor, with the same data type as `x`. If `reduction` is "none", its shape is (N); otherwise, the output is a scalar.
Raises:
- **ValueError** - If the rank of `x` or `target` is not 2.
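A minimal sketch of the functional interface with a per-class `weight`, which scales each column of the element-wise loss before the class average (the values mirror the parametrized tests below):

>>> import mindspore.ops as ops
>>> from mindspore import Tensor
>>> x = Tensor([[0.3, 0.6, 0.6], [0.9, 0.4, 0.2]])
>>> target = Tensor([[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]])
>>> weight = Tensor([1.0, 1.5, 0.8])
>>> out = ops.multilabel_soft_margin_loss(x, target, weight, reduction='none')
>>> print(out.asnumpy())
[0.920193 1.029729]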

View File

@@ -249,6 +249,7 @@ Loss Function
mindspore.nn.MarginRankingLoss
mindspore.nn.MSELoss
mindspore.nn.MultiClassDiceLoss
mindspore.nn.MultiLabelSoftMarginLoss
mindspore.nn.NLLLoss
mindspore.nn.RMSELoss
mindspore.nn.SampledSoftmaxLoss

View File

@@ -45,6 +45,7 @@ Neural Network
mindspore.ops.max_unpool3d
mindspore.ops.multi_margin_loss
mindspore.ops.multi_label_margin_loss
mindspore.ops.multilabel_soft_margin_loss
mindspore.ops.kl_div
mindspore.ops.pad
mindspore.ops.padding

View File

@@ -1269,10 +1269,10 @@ class MultiLabelSoftMarginLoss(LossBase):
:math:`x_{i}` is the `x`. If `weight` is given, it is multiplied with the loss of each class.
Args:
weight (int, float): The manual rescaling weight given to each class. Default: None.
reduction (string): Specifies which reduction to be applied to the output. It must be one of
weight (Union[Tensor, int, float]): The manual rescaling weight given to each class. Default: None.
reduction (str): Specifies which reduction to be applied to the output. It must be one of
'none', 'mean', and 'sum', meaning no reduction, reduce mean and sum on output, respectively.
Default 'mean'.
Default: 'mean'.
Inputs:
- **x** (Tensor) - A tensor of shape (N, C), where N is batch size and C is number
@@ -1280,8 +1280,7 @@ class MultiLabelSoftMarginLoss(LossBase):
- **target** (Tensor) - The label target Tensor which has the same shape as `x`.
Outputs:
Tensor or Scalar, if `reduction` is 'none', the output is a tensor of shape (N) with the same data type as `x`.
Otherwise it is a scalar.
Tensor, with the same data type as `x`. If `reduction` is 'none', its shape is (N); otherwise, the output is a scalar.
Raises:
ValueError: If the rank of `x` or `target` is not 2.
@@ -1302,27 +1301,10 @@ class MultiLabelSoftMarginLoss(LossBase):
"""Initialize MultiLabelSoftMarginLoss."""
super(MultiLabelSoftMarginLoss, self).__init__(reduction)
self.weight = weight
self.mul = P.Mul()
self.exp = P.Exp()
self.add = P.Add()
self.log = P.Log()
self.reduction = reduction
def construct(self, x, target):
_check_is_tensor('x', x, self.cls_name)
_check_is_tensor('target', target, self.cls_name)
if x.ndim != 2 or target.ndim != 2:
raise ValueError(
"For 'MultiLabelSoftMarginLoss', the inputs must be 2d tensor, but got shapes: "
f"x: {x.shape}, target: {target.shape} "
)
pos = self.log(self.add(self.exp(-x), 1))
neg = self.log(self.add(self.exp(x), 1))
loss = target * pos + (1 - target) * neg
if self.weight is not None:
loss = loss * self.weight
class_dim = x.ndim - 1
loss = loss.sum(axis=class_dim) / x.shape[class_dim]
return self.get_loss(loss)
return F.multilabel_soft_margin_loss(x, target, self.weight, self.reduction)
class MultiMarginLoss(LossBase):

View File

@@ -437,6 +437,7 @@ from .nn_func import (
glu,
multi_margin_loss,
multi_label_margin_loss,
multilabel_soft_margin_loss,
elu,
gelu,
hinge_embedding_loss,

View File

@@ -4710,6 +4710,68 @@ def multi_label_margin_loss(inputs, target, reduction='mean'):
return outputs
def multilabel_soft_margin_loss(x, target, weight=None, reduction='mean'):
r"""
Calculates the multilabel soft margin loss.
Creates a criterion for optimizing a multi-label one-versus-all loss based on maximum entropy.
.. math::
\mathcal{L}_{D} = - \frac{1}{|D|}\sum_{i = 0}^{|D|}\left(
y_{i}\ln\frac{1}{1 + e^{- x_{i}}} + \left( 1 - y_{i}
\right)\ln\frac{1}{1 + e^{x_{i}}} \right)
where :math:`\mathcal{L}_{D}` is the loss, :math:`y_{i}` is the `target`,
and :math:`x_{i}` is the `x`. If `weight` is given, it is multiplied with the loss of each class.
Args:
x (Tensor): A tensor of shape (N, C), where N is batch size and C is number of classes.
target (Tensor): The label target Tensor which has the same shape as `x`.
weight (Union[Tensor, int, float]): The manual rescaling weight given to each class. Default: None.
reduction (str): Specifies which reduction to be applied to the output. It must be one of
'none', 'mean', and 'sum', meaning no reduction, reduce mean and sum on output, respectively.
Default: 'mean'.
Returns:
Tensor, with the same data type as `x`. If `reduction` is 'none', its shape is (N); otherwise, the output is a scalar.
Raises:
ValueError: If the rank of `x` or `target` is not 2.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor([[0.3, 0.6, 0.6], [0.9, 0.4, 0.2]])
>>> target = Tensor([[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]])
>>> out = ops.multilabel_soft_margin_loss(x, target, reduction='mean')
>>> print(out.asnumpy())
0.84693956
"""
cls_name = "multilabel_soft_margin_loss"
_check_is_tensor('x', x, cls_name)
_check_is_tensor('target', target, cls_name)
if x.ndim != 2 or target.ndim != 2:
raise ValueError(
"For 'MultiLabelSoftMarginLoss', the inputs must be 2d tensor, but got shapes: "
f"x: {x.shape}, target: {target.shape} "
)
mul_op = _get_cache_prim(P.Mul)()
exp_op = _get_cache_prim(P.Exp)()
add_op = _get_cache_prim(P.Add)()
log_op = _get_cache_prim(P.Log)()
pos = log_op(add_op(exp_op(-x), 1))
neg = log_op(add_op(exp_op(x), 1))
loss = mul_op(target, pos) + mul_op(1 - target, neg)
if weight is not None:
loss = mul_op(loss, weight)
class_dim = x.ndim - 1
loss = loss.sum(axis=class_dim) / x.shape[class_dim]
return _get_loss(loss, reduction, cls_name)
def elu(input_x, alpha=1.0):
r"""
Exponential Linear Unit activation function.
@@ -5177,6 +5239,7 @@ __all__ = [
'margin_ranking_loss',
'multi_margin_loss',
'multi_label_margin_loss',
'multilabel_soft_margin_loss',
'elu',
'gelu',
'hinge_embedding_loss',

View File

@@ -0,0 +1,72 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
class MultiLabelSoftMarginLossNet(nn.Cell):
def __init__(self, weight=None, reduction='mean'):
super(MultiLabelSoftMarginLossNet, self).__init__()
self.multilabel_soft_margin_loss = nn.MultiLabelSoftMarginLoss(weight=weight, reduction=reduction)
def construct(self, x, target):
return self.multilabel_soft_margin_loss(x, target)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
@pytest.mark.parametrize('weight', [None, Tensor([1.0, 1.5, 0.8], mstype.float32)])
@pytest.mark.parametrize('reduction', ['mean', 'none', 'sum'])
def test_multilabel_soft_margin_loss(mode, weight, reduction):
"""
Feature: MultiLabelSoftMarginLoss with weight=[None, Tensor([1.0, 1.5, 0.8], mstype.float32)],
reduction=['mean', 'none', 'sum']
Description: Verify the result of MultiLabelSoftMarginLoss
Expectation: success
"""
context.set_context(mode=mode)
net = MultiLabelSoftMarginLossNet(weight=weight, reduction=reduction)
arr1 = np.array([[0.3, 0.6, 0.6], [0.9, 0.4, 0.2]], np.float32)
arr2 = np.array([[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]], np.float32)
x = Tensor(arr1, mstype.float32)
label = Tensor(arr2, mstype.float32)
output = net(x, label)
if weight is None:
if reduction == 'mean':
expected = np.array(0.846940, np.float32)
elif reduction == 'sum':
expected = np.array(1.693880, np.float32)
else:
expected = np.array([0.776444, 0.917436], np.float32)
else:
if reduction == 'mean':
expected = np.array(0.974961, np.float32)
elif reduction == 'sum':
expected = np.array(1.949922, np.float32)
else:
expected = np.array([0.920193, 1.029729], np.float32)
assert np.allclose(output.asnumpy(), expected)
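# Sanity check (not part of this commit): the expected values above follow
# directly from the documented formula; a minimal NumPy reference sketch:
import numpy as np

def ref_loss(x, target, weight=None, reduction='mean'):
    # log1p(exp(-x)) = -log(sigmoid(x)); log1p(exp(x)) = -log(1 - sigmoid(x))
    loss = target * np.log1p(np.exp(-x)) + (1 - target) * np.log1p(np.exp(x))
    if weight is not None:
        loss = loss * weight                      # per-class rescaling
    loss = loss.sum(axis=-1) / x.shape[-1]        # mean over classes -> shape (N,)
    if reduction == 'mean':
        return loss.mean()
    return loss.sum() if reduction == 'sum' else loss

x = np.array([[0.3, 0.6, 0.6], [0.9, 0.4, 0.2]])
t = np.array([[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]])
print(ref_loss(x, t, reduction='none'))                    # ~[0.776444 0.917435]
print(ref_loss(x, t, weight=np.array([1.0, 1.5, 0.8])))    # ~0.974960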

View File

@@ -0,0 +1,74 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore import context
class MultiLabelSoftMarginLossNet(nn.Cell):
def __init__(self, weight=None, reduction='mean'):
super(MultiLabelSoftMarginLossNet, self).__init__()
self.weight = weight
self.reduction = reduction
def construct(self, x, target):
return ops.multilabel_soft_margin_loss(x, target, self.weight, self.reduction)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
@pytest.mark.parametrize('weight', [None, Tensor([1.0, 1.5, 0.8], mstype.float32)])
@pytest.mark.parametrize('reduction', ['mean', 'none', 'sum'])
def test_multilabel_soft_margin_loss(mode, weight, reduction):
"""
Feature: ops.multilabel_soft_margin_loss with weight=[None, Tensor([1.0, 1.5, 0.8], mstype.float32)],
reduction=['mean', 'none', 'sum']
Description: Verify the result of ops.multilabel_soft_margin_loss
Expectation: success
"""
context.set_context(mode=mode)
net = MultiLabelSoftMarginLossNet(weight=weight, reduction=reduction)
arr1 = np.array([[0.3, 0.6, 0.6], [0.9, 0.4, 0.2]], np.float32)
arr2 = np.array([[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]], np.float32)
x = Tensor(arr1, mstype.float32)
label = Tensor(arr2, mstype.float32)
output = net(x, label)
if weight is None:
if reduction == 'mean':
expected = np.array(0.846940, np.float32)
elif reduction == 'sum':
expected = np.array(1.693880, np.float32)
else:
expected = np.array([0.776444, 0.917436], np.float32)
else:
if reduction == 'mean':
expected = np.array(0.974961, np.float32)
elif reduction == 'sum':
expected = np.array(1.949922, np.float32)
else:
expected = np.array([0.920193, 1.029729], np.float32)
assert np.allclose(output.asnumpy(), expected)