!43911 Support MarginRankingLoss

Merge pull request !43911 from 冯一航/support_MarginRankingLoss
This commit is contained in:
i-robot 2022-10-18 11:45:56 +00:00 committed by Gitee
commit a60615e409
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
7 changed files with 231 additions and 5 deletions

View File

@ -233,6 +233,7 @@ Dropout层
mindspore.nn.HuberLoss
mindspore.nn.KLDivLoss
mindspore.nn.L1Loss
mindspore.nn.MarginRankingLoss
mindspore.nn.MSELoss
mindspore.nn.MultiClassDiceLoss
mindspore.nn.NLLLoss

View File

@ -0,0 +1,32 @@
mindspore.nn.MarginRankingLoss
===============================
.. py:class:: mindspore.nn.MarginRankingLoss(margin=0.0, reduction='mean')
排序损失函数,用于创建一个衡量给定损失的标准。
给定两个Tensor :math:`x1`:math:`x2` 以及一个Tensor标签 :math:`y` 值为1或-1公式如下
.. math::
\text{loss}(x1, x2, y) = \max(0, -y * (x1 - x2) + \text{margin})
参数:
- **margin** (float) - 指定运算的调节因子。默认值0.0。
- **reduction** (str) - 指定输出结果的计算方式。可选值为"none"、"mean"或"sum",分别表示不指定计算方式、使用均值计算和使用求和计算。默认值:"mean"。
输入:
- **input1** (Tensor) - 输入Tensorshape :math:`(N,*)` ,其中 `*` 代表任意数量的附加维度。
- **input2** (Tensor) - 输入Tensorshape :math:`(N,*)` 。shape和数据类型与 `input1` 相同。
- **target** (Tensor) - 输入值为1或-1。假设 `input1` 的shape是 :math:`(x_1, x_2, x_3, ..., x_R)` ,那么 `labels` 的shape必须是 :math:`(x_1, x_3, x_4, ..., x_R)`
输出:
Tensor或Scalar如果 `reduction` 为"none"其shape与 `labels` 相同。否则将返回为Scalar。
异常:
- **TypeError** - `margin` 不是float。
- **TypeError** - `input1` `input2``target` 不是Tensor。
- **TypeError** - `input1``input2` 类型不一致。
- **TypeError** - `input1``target` 类型不一致。
- **ValueError** - `input1``input2` shape不一致。
- **ValueError** - `input1``target` shape不一致。
- **ValueError** - `reduction` 不为"none""mean"或"sum"。

View File

@ -233,6 +233,7 @@ Loss Function
mindspore.nn.HuberLoss
mindspore.nn.KLDivLoss
mindspore.nn.L1Loss
mindspore.nn.MarginRankingLoss
mindspore.nn.MSELoss
mindspore.nn.MultiClassDiceLoss
mindspore.nn.NLLLoss

View File

@ -20,13 +20,12 @@ It shows how well the model works on a dataset and the optimization target which
"""
from __future__ import absolute_import
from mindspore.nn.loss.loss import LossBase, L1Loss, MSELoss, SmoothL1Loss, SoftMarginLoss, FocalLoss,\
from mindspore.nn.loss.loss import LossBase, L1Loss, MSELoss, SmoothL1Loss, SoftMarginLoss, FocalLoss, \
SoftmaxCrossEntropyWithLogits, BCELoss, MultiMarginLoss, CosineEmbeddingLoss, \
SampledSoftmaxLoss, DiceLoss, BCEWithLogitsLoss, MultiClassDiceLoss,\
RMSELoss, MAELoss, HuberLoss, CrossEntropyLoss, NLLLoss, KLDivLoss
SampledSoftmaxLoss, DiceLoss, BCEWithLogitsLoss, MultiClassDiceLoss, \
RMSELoss, MAELoss, HuberLoss, CrossEntropyLoss, NLLLoss, KLDivLoss, MarginRankingLoss
__all__ = ['LossBase', 'L1Loss', 'MSELoss', 'SmoothL1Loss', 'SoftMarginLoss', 'FocalLoss',
'SoftmaxCrossEntropyWithLogits', 'BCELoss', 'BCEWithLogitsLoss', 'MultiMarginLoss',
'CosineEmbeddingLoss', 'SampledSoftmaxLoss', 'DiceLoss', 'MultiClassDiceLoss',
'RMSELoss', 'MAELoss', 'HuberLoss', 'CrossEntropyLoss', 'NLLLoss', 'KLDivLoss']
'RMSELoss', 'MAELoss', 'HuberLoss', 'CrossEntropyLoss', 'NLLLoss', 'KLDivLoss', 'MarginRankingLoss']

View File

@ -457,6 +457,86 @@ class MAELoss(LossBase):
return self.get_loss(x)
class MarginRankingLoss(LossBase):
r"""
MarginRankingLoss creates a criterion that measures the loss.
Given two tensors :math:`x1`, :math:`x2` and a Tensor label :math:`y` with values 1 or -1,
the operation is as follows:
.. math::
\text{loss}(x1, x2, y) = \max(0, -y * (x1 - x2) + \text{margin})
Args:
margin (float): Specify the adjustment factor of the operation. Default 0.0.
reduction (str): Specifies which reduction to be applied to the output. It must be one of
"none", "mean", and "sum", meaning no reduction, reduce mean and sum on output, respectively. Default "mean".
Inputs:
- **input1** (Tensor) - Tensor of shape :math:`(N, *)` where :math:`*` means, any number
of additional dimensions.
- **input2** (Tensor) - Tensor of shape :math:`(N, *)`, same shape and dtype as `input1`.
- **target** (Tensor) - Contains value 1 or -1. Suppose the shape of `input1` is
:math:`(x_1, x_2, x_3, ..., x_R)`, then the shape of `labels` must be :math:`(x_1, x_3, x_4, ..., x_R)`.
Outputs:
Tensor or Scalar. if `reduction` is "none", its shape is the same as `labels`.
Otherwise, a scalar value will be returned.
Raises:
TypeError: If `margin` is not a float.
TypeError: If `input1`, `input2` or `target` is not a Tensor.
TypeError: If the types of `input1` and `input2` are inconsistent.
TypeError: If the types of `input1` and `target` are inconsistent.
ValueError: If the shape of `input1` and `input2` are inconsistent.
ValueError: If the shape of `input1` and `target` are inconsistent.
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore as ms
>>> import mindspore.nn as nn
>>> import mindspore.ops as P
>>> from mindspore.ops import Tensor
>>> import numpy as np
>>> loss1 = nn.MarginRankingLoss(reduction='none')
>>> loss2 = nn.MarginRankingLoss(reduction='mean')
>>> loss3 = nn.MarginRankingLoss(reduction='sum')
>>> sign = P.Sign()
>>> input1 = Tensor(np.array([0.3864, -2.4093, -1.4076]), ms.float32)
>>> input2 = Tensor(np.array([-0.6012, -1.6681, 1.2928]), ms.float32)
>>> target = sign(Tensor(np.array([-2, -2, 3]), ms.float32))
>>> output1 = loss1(input1, input2, target)
>>> print(output1)
[0.98759997 0. 2.7003999 ]
>>> output2 = loss2(input1, input2, target)
>>> print(output2)
1.2293333
>>> output3 = loss3(input1, input2, target)
>>> print(output3)
3.6879997
"""
def __init__(self, margin=0.0, reduction='mean'):
"""Initialize MarginRankingLoss."""
super(MarginRankingLoss, self).__init__(reduction)
self.margin = validator.check_value_type("margin", margin, [float], self.cls_name)
self.reduction = reduction
self.margin = margin
self.maximum = P.Maximum()
def construct(self, input1, input2, target):
_check_is_tensor('input1', input1, self.cls_name)
_check_is_tensor('input2', input2, self.cls_name)
_check_is_tensor('target', target, self.cls_name)
F.same_type_shape(input1, input2)
F.same_type_shape(target, input1)
x = self.maximum(0, -target * (input1 - input2) + self.margin)
return self.get_loss(x)
class SmoothL1Loss(LossBase):
r"""
SmoothL1 loss function, if the absolute error element-wise between the predicted value and the target value
@ -2150,6 +2230,7 @@ class KLDivLoss(LossBase):
>>> print(output)
-0.23333333
"""
def __init__(self, reduction='mean'):
super().__init__()
self.reduction = reduction

View File

@ -0,0 +1,98 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
class MarginRankingLoss(nn.Cell):
def __init__(self, reduction="none"):
super(MarginRankingLoss, self).__init__()
self.margin_ranking_loss = nn.MarginRankingLoss(reduction=reduction)
def construct(self, x, y, label):
return self.margin_ranking_loss(x, y, label)
input1 = Tensor(np.array([0.3864, -2.4093, -1.4076]), ms.float32)
input2 = Tensor(np.array([-0.6012, -1.6681, 1.2928]), ms.float32)
target = Tensor(np.array([-1, -1, 1]), ms.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_margin_ranking_loss_none(mode):
"""
Feature: test MarginRankingLoss op with reduction none.
Description: Verify the result of MarginRankingLoss.
Expectation: expect correct forward result.
"""
ms.set_context(mode=mode)
loss = MarginRankingLoss('none')
output = loss(input1, input2, target)
expect_output = np.array([0.98759997, 0., 2.7003999])
assert np.allclose(output.asnumpy(), expect_output)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_margin_ranking_loss_sum(mode):
"""
Feature: test MarginRankingLoss op with reduction sum.
Description: Verify the result of MarginRankingLoss.
Expectation: expect correct forward result.
"""
ms.set_context(mode=mode)
loss = MarginRankingLoss('sum')
output = loss(input1, input2, target)
expect_output = np.array(3.6879997)
assert np.allclose(output.asnumpy(), expect_output)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_margin_ranking_loss_mean(mode):
"""
Feature: test MarginRankingLoss op with reduction mean.
Description: Verify the result of MarginRankingLoss.
Expectation: expect correct forward result.
"""
ms.set_context(mode=mode)
loss = MarginRankingLoss('mean')
output = loss(input1, input2, target)
expect_output = np.array(1.2293333)
assert np.allclose(output.asnumpy(), expect_output)

View File

@ -15,6 +15,7 @@
""" test loss """
import numpy as np
import pytest
import mindspore as ms
from mindspore.common import dtype as mstype
from mindspore import nn
from mindspore import Tensor
@ -292,3 +293,16 @@ def test_nll_loss_4d():
input_data = Tensor(np.random.randn(3, 5, 1, 1).astype(np.float32))
target_data = Tensor(np.array([[[1]], [[0]], [[4]]]).astype(np.int32))
loss(input_data, target_data)
def test_margin_ranking_loss():
"""
Feature: Test MarginRankingLoss.
Description: Test MarginRankingLoss functional.
Expectation: Success.
"""
loss = nn.MarginRankingLoss()
input1 = Tensor(np.array([0.3864, -2.4093, -1.4076]), ms.float32)
input2 = Tensor(np.array([-0.6012, -1.6681, 1.2928]), ms.float32)
target = Tensor(np.array([-1, -1, 1]), ms.float32)
loss(input1, input2, target)