!43911 Support MarginRankingLoss
Merge pull request !43911 from 冯一航/support_MarginRankingLoss
This commit is contained in:
commit
a60615e409
|
@ -233,6 +233,7 @@ Dropout层
|
|||
mindspore.nn.HuberLoss
|
||||
mindspore.nn.KLDivLoss
|
||||
mindspore.nn.L1Loss
|
||||
mindspore.nn.MarginRankingLoss
|
||||
mindspore.nn.MSELoss
|
||||
mindspore.nn.MultiClassDiceLoss
|
||||
mindspore.nn.NLLLoss
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
mindspore.nn.MarginRankingLoss
|
||||
===============================
|
||||
|
||||
.. py:class:: mindspore.nn.MarginRankingLoss(margin=0.0, reduction='mean')
|
||||
|
||||
排序损失函数,用于创建一个衡量给定损失的标准。
|
||||
|
||||
给定两个Tensor :math:`x1` 和 :math:`x2` ,以及一个Tensor标签 :math:`y` ,值为1或-1,公式如下:
|
||||
|
||||
.. math::
|
||||
\text{loss}(x1, x2, y) = \max(0, -y * (x1 - x2) + \text{margin})
|
||||
|
||||
参数:
|
||||
- **margin** (float) - 指定运算的调节因子。默认值:0.0。
|
||||
- **reduction** (str) - 指定输出结果的计算方式。可选值为"none"、"mean"或"sum",分别表示不指定计算方式、使用均值计算和使用求和计算。默认值:"mean"。
|
||||
|
||||
输入:
|
||||
- **input1** (Tensor) - 输入Tensor,shape :math:`(N,*)` ,其中 `*` 代表任意数量的附加维度。
|
||||
- **input2** (Tensor) - 输入Tensor,shape :math:`(N,*)` 。shape和数据类型与 `input1` 相同。
|
||||
- **target** (Tensor) - 输入值为1或-1。假设 `input1` 的shape是 :math:`(x_1, x_2, x_3, ..., x_R)` ,那么 `labels` 的shape必须是 :math:`(x_1, x_3, x_4, ..., x_R)` 。
|
||||
|
||||
输出:
|
||||
Tensor或Scalar,如果 `reduction` 为"none",其shape与 `labels` 相同。否则,将返回为Scalar。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `margin` 不是float。
|
||||
- **TypeError** - `input1` ,`input2` 和 `target` 不是Tensor。
|
||||
- **TypeError** - `input1` 和 `input2` 类型不一致。
|
||||
- **TypeError** - `input1` 和 `target` 类型不一致。
|
||||
- **ValueError** - `input1` 和 `input2` shape不一致。
|
||||
- **ValueError** - `input1` 和 `target` shape不一致。
|
||||
- **ValueError** - `reduction` 不为"none","mean"或"sum"。
|
|
@ -233,6 +233,7 @@ Loss Function
|
|||
mindspore.nn.HuberLoss
|
||||
mindspore.nn.KLDivLoss
|
||||
mindspore.nn.L1Loss
|
||||
mindspore.nn.MarginRankingLoss
|
||||
mindspore.nn.MSELoss
|
||||
mindspore.nn.MultiClassDiceLoss
|
||||
mindspore.nn.NLLLoss
|
||||
|
|
|
@ -20,13 +20,12 @@ It shows how well the model works on a dataset and the optimization target which
|
|||
"""
|
||||
from __future__ import absolute_import
|
||||
|
||||
from mindspore.nn.loss.loss import LossBase, L1Loss, MSELoss, SmoothL1Loss, SoftMarginLoss, FocalLoss,\
|
||||
from mindspore.nn.loss.loss import LossBase, L1Loss, MSELoss, SmoothL1Loss, SoftMarginLoss, FocalLoss, \
|
||||
SoftmaxCrossEntropyWithLogits, BCELoss, MultiMarginLoss, CosineEmbeddingLoss, \
|
||||
SampledSoftmaxLoss, DiceLoss, BCEWithLogitsLoss, MultiClassDiceLoss,\
|
||||
RMSELoss, MAELoss, HuberLoss, CrossEntropyLoss, NLLLoss, KLDivLoss
|
||||
|
||||
SampledSoftmaxLoss, DiceLoss, BCEWithLogitsLoss, MultiClassDiceLoss, \
|
||||
RMSELoss, MAELoss, HuberLoss, CrossEntropyLoss, NLLLoss, KLDivLoss, MarginRankingLoss
|
||||
|
||||
__all__ = ['LossBase', 'L1Loss', 'MSELoss', 'SmoothL1Loss', 'SoftMarginLoss', 'FocalLoss',
|
||||
'SoftmaxCrossEntropyWithLogits', 'BCELoss', 'BCEWithLogitsLoss', 'MultiMarginLoss',
|
||||
'CosineEmbeddingLoss', 'SampledSoftmaxLoss', 'DiceLoss', 'MultiClassDiceLoss',
|
||||
'RMSELoss', 'MAELoss', 'HuberLoss', 'CrossEntropyLoss', 'NLLLoss', 'KLDivLoss']
|
||||
'RMSELoss', 'MAELoss', 'HuberLoss', 'CrossEntropyLoss', 'NLLLoss', 'KLDivLoss', 'MarginRankingLoss']
|
||||
|
|
|
@ -457,6 +457,86 @@ class MAELoss(LossBase):
|
|||
return self.get_loss(x)
|
||||
|
||||
|
||||
class MarginRankingLoss(LossBase):
|
||||
r"""
|
||||
MarginRankingLoss creates a criterion that measures the loss.
|
||||
|
||||
Given two tensors :math:`x1`, :math:`x2` and a Tensor label :math:`y` with values 1 or -1,
|
||||
the operation is as follows:
|
||||
|
||||
.. math::
|
||||
\text{loss}(x1, x2, y) = \max(0, -y * (x1 - x2) + \text{margin})
|
||||
|
||||
Args:
|
||||
margin (float): Specify the adjustment factor of the operation. Default 0.0.
|
||||
reduction (str): Specifies which reduction to be applied to the output. It must be one of
|
||||
"none", "mean", and "sum", meaning no reduction, reduce mean and sum on output, respectively. Default "mean".
|
||||
|
||||
Inputs:
|
||||
- **input1** (Tensor) - Tensor of shape :math:`(N, *)` where :math:`*` means, any number
|
||||
of additional dimensions.
|
||||
- **input2** (Tensor) - Tensor of shape :math:`(N, *)`, same shape and dtype as `input1`.
|
||||
- **target** (Tensor) - Contains value 1 or -1. Suppose the shape of `input1` is
|
||||
:math:`(x_1, x_2, x_3, ..., x_R)`, then the shape of `labels` must be :math:`(x_1, x_3, x_4, ..., x_R)`.
|
||||
|
||||
Outputs:
|
||||
Tensor or Scalar. if `reduction` is "none", its shape is the same as `labels`.
|
||||
Otherwise, a scalar value will be returned.
|
||||
|
||||
Raises:
|
||||
TypeError: If `margin` is not a float.
|
||||
TypeError: If `input1`, `input2` or `target` is not a Tensor.
|
||||
TypeError: If the types of `input1` and `input2` are inconsistent.
|
||||
TypeError: If the types of `input1` and `target` are inconsistent.
|
||||
ValueError: If the shape of `input1` and `input2` are inconsistent.
|
||||
ValueError: If the shape of `input1` and `target` are inconsistent.
|
||||
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore as ms
|
||||
>>> import mindspore.nn as nn
|
||||
>>> import mindspore.ops as P
|
||||
>>> from mindspore.ops import Tensor
|
||||
>>> import numpy as np
|
||||
>>> loss1 = nn.MarginRankingLoss(reduction='none')
|
||||
>>> loss2 = nn.MarginRankingLoss(reduction='mean')
|
||||
>>> loss3 = nn.MarginRankingLoss(reduction='sum')
|
||||
>>> sign = P.Sign()
|
||||
>>> input1 = Tensor(np.array([0.3864, -2.4093, -1.4076]), ms.float32)
|
||||
>>> input2 = Tensor(np.array([-0.6012, -1.6681, 1.2928]), ms.float32)
|
||||
>>> target = sign(Tensor(np.array([-2, -2, 3]), ms.float32))
|
||||
>>> output1 = loss1(input1, input2, target)
|
||||
>>> print(output1)
|
||||
[0.98759997 0. 2.7003999 ]
|
||||
>>> output2 = loss2(input1, input2, target)
|
||||
>>> print(output2)
|
||||
1.2293333
|
||||
>>> output3 = loss3(input1, input2, target)
|
||||
>>> print(output3)
|
||||
3.6879997
|
||||
"""
|
||||
|
||||
def __init__(self, margin=0.0, reduction='mean'):
|
||||
"""Initialize MarginRankingLoss."""
|
||||
super(MarginRankingLoss, self).__init__(reduction)
|
||||
self.margin = validator.check_value_type("margin", margin, [float], self.cls_name)
|
||||
self.reduction = reduction
|
||||
self.margin = margin
|
||||
self.maximum = P.Maximum()
|
||||
|
||||
def construct(self, input1, input2, target):
|
||||
_check_is_tensor('input1', input1, self.cls_name)
|
||||
_check_is_tensor('input2', input2, self.cls_name)
|
||||
_check_is_tensor('target', target, self.cls_name)
|
||||
F.same_type_shape(input1, input2)
|
||||
F.same_type_shape(target, input1)
|
||||
x = self.maximum(0, -target * (input1 - input2) + self.margin)
|
||||
return self.get_loss(x)
|
||||
|
||||
|
||||
class SmoothL1Loss(LossBase):
|
||||
r"""
|
||||
SmoothL1 loss function, if the absolute error element-wise between the predicted value and the target value
|
||||
|
@ -2150,6 +2230,7 @@ class KLDivLoss(LossBase):
|
|||
>>> print(output)
|
||||
-0.23333333
|
||||
"""
|
||||
|
||||
def __init__(self, reduction='mean'):
|
||||
super().__init__()
|
||||
self.reduction = reduction
|
||||
|
|
|
@ -0,0 +1,98 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
|
||||
|
||||
class MarginRankingLoss(nn.Cell):
|
||||
def __init__(self, reduction="none"):
|
||||
super(MarginRankingLoss, self).__init__()
|
||||
self.margin_ranking_loss = nn.MarginRankingLoss(reduction=reduction)
|
||||
|
||||
def construct(self, x, y, label):
|
||||
return self.margin_ranking_loss(x, y, label)
|
||||
|
||||
|
||||
input1 = Tensor(np.array([0.3864, -2.4093, -1.4076]), ms.float32)
|
||||
input2 = Tensor(np.array([-0.6012, -1.6681, 1.2928]), ms.float32)
|
||||
target = Tensor(np.array([-1, -1, 1]), ms.float32)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_margin_ranking_loss_none(mode):
|
||||
"""
|
||||
Feature: test MarginRankingLoss op with reduction none.
|
||||
Description: Verify the result of MarginRankingLoss.
|
||||
Expectation: expect correct forward result.
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
loss = MarginRankingLoss('none')
|
||||
output = loss(input1, input2, target)
|
||||
expect_output = np.array([0.98759997, 0., 2.7003999])
|
||||
assert np.allclose(output.asnumpy(), expect_output)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_margin_ranking_loss_sum(mode):
|
||||
"""
|
||||
Feature: test MarginRankingLoss op with reduction sum.
|
||||
Description: Verify the result of MarginRankingLoss.
|
||||
Expectation: expect correct forward result.
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
loss = MarginRankingLoss('sum')
|
||||
output = loss(input1, input2, target)
|
||||
expect_output = np.array(3.6879997)
|
||||
assert np.allclose(output.asnumpy(), expect_output)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_margin_ranking_loss_mean(mode):
|
||||
"""
|
||||
Feature: test MarginRankingLoss op with reduction mean.
|
||||
Description: Verify the result of MarginRankingLoss.
|
||||
Expectation: expect correct forward result.
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
loss = MarginRankingLoss('mean')
|
||||
output = loss(input1, input2, target)
|
||||
expect_output = np.array(1.2293333)
|
||||
assert np.allclose(output.asnumpy(), expect_output)
|
|
@ -15,6 +15,7 @@
|
|||
""" test loss """
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore as ms
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore import nn
|
||||
from mindspore import Tensor
|
||||
|
@ -292,3 +293,16 @@ def test_nll_loss_4d():
|
|||
input_data = Tensor(np.random.randn(3, 5, 1, 1).astype(np.float32))
|
||||
target_data = Tensor(np.array([[[1]], [[0]], [[4]]]).astype(np.int32))
|
||||
loss(input_data, target_data)
|
||||
|
||||
|
||||
def test_margin_ranking_loss():
|
||||
"""
|
||||
Feature: Test MarginRankingLoss.
|
||||
Description: Test MarginRankingLoss functional.
|
||||
Expectation: Success.
|
||||
"""
|
||||
loss = nn.MarginRankingLoss()
|
||||
input1 = Tensor(np.array([0.3864, -2.4093, -1.4076]), ms.float32)
|
||||
input2 = Tensor(np.array([-0.6012, -1.6681, 1.2928]), ms.float32)
|
||||
target = Tensor(np.array([-1, -1, 1]), ms.float32)
|
||||
loss(input1, input2, target)
|
||||
|
|
Loading…
Reference in New Issue