forked from mindspore-Ecosystem/mindspore
Add nn.HingeEmbeddingLoss
This commit is contained in:
parent
79e56d6a84
commit
17e8e40ec5
|
@ -237,6 +237,7 @@ Dropout层
|
|||
mindspore.nn.DiceLoss
|
||||
mindspore.nn.FocalLoss
|
||||
mindspore.nn.GaussianNLLLoss
|
||||
mindspore.nn.HingeEmbeddingLoss
|
||||
mindspore.nn.HuberLoss
|
||||
mindspore.nn.KLDivLoss
|
||||
mindspore.nn.L1Loss
|
||||
|
|
|
@ -55,6 +55,7 @@ mindspore.ops.function
|
|||
mindspore.ops.binary_cross_entropy
|
||||
mindspore.ops.binary_cross_entropy_with_logits
|
||||
mindspore.ops.cross_entropy
|
||||
mindspore.ops.hinge_embedding_loss
|
||||
mindspore.ops.nll_loss
|
||||
mindspore.ops.smooth_l1_loss
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@ mindspore.nn.GaussianNLLLoss
|
|||
- **logits** (Tensor) - shape为 :math:`(N, *)` 或 :math:`(*)`。`*` 代表着任意数量的额外维度。
|
||||
- **labels** (Tensor) - shape为 :math:`(N, *)` 或 :math:`(*)`。和 `logits` 具有相同shape,或者相同shape但有一个维度为1(以允许广播)。
|
||||
- **var** (Tensor) - shape为 :math:`(N, *)` 或 :math:`(*)`。和 `logits` 具有相同shape,或者相同shape但有一个维度为1,或者少一个维度(以允许广播)。
|
||||
|
||||
返回:
|
||||
Tensor或Tensor scalar,根据 :math:`reduction` 计算的loss。
|
||||
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
mindspore.nn.HingeEmbeddingLoss
|
||||
===============================
|
||||
|
||||
.. py:class:: mindspore.nn.HingeEmbeddingLoss(margin=1.0, reduction="mean")
|
||||
|
||||
Hinge Embedding 损失函数。按输入元素计算输出。衡量输入张量x和标签y(包含1或-1)之间的损失值。通常被用来衡量两个输入之间的相似度。
|
||||
|
||||
mini-batch中的第n个样例的损失函数为:
|
||||
|
||||
.. math::
|
||||
l_n = \begin{cases}
|
||||
x_n, & \text{if}\; y_n = 1,\\
|
||||
\max \{0, \Delta - x_n\}, & \text{if}\; y_n = -1,
|
||||
\end{cases}
|
||||
|
||||
总损失值为:
|
||||
|
||||
.. math::
|
||||
\ell(x, y) = \begin{cases}
|
||||
\operatorname{mean}(L), & \text{if reduction} = \text{`mean`;}\\
|
||||
\operatorname{sum}(L), & \text{if reduction} = \text{`sum`.}
|
||||
\end{cases}
|
||||
|
||||
其中 :math:`L = \{l_1,\dots,l_N\}^\top`。
|
||||
|
||||
参数:
|
||||
- **margin** (float) - Hinge Embedding Loss公式定义的阈值 :math:`margin`。公式中表示为 :math:`\Delta`。默认值:1.0。
|
||||
- **reduction** (str) - 指定应用于输出结果的计算方式,'none'、'mean'、'sum',默认值:'mean'。
|
||||
|
||||
输入:
|
||||
- **logits** (Tensor) - 预测值,公式中表示为 :math:`x`,shape为:math:`(*)`。`*` 代表着任意数量的维度。
|
||||
- **labels** (Tensor) - 标签值,公式中表示为 :math:`y`,和 `logits` 具有相同shape,包含1或-1。
|
||||
|
||||
返回:
|
||||
Tensor或Tensor scalar,根据 :math:`reduction` 计算的loss。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `logits` 不是Tensor。
|
||||
- **TypeError** - `labels` 不是Tensor。
|
||||
- **TypeError** - `margin` 不是float。
|
||||
- **ValueError** - `labels` 和 `logits` shape不一致。
|
||||
- **ValueError** - `reduction` 不是"none"、"mean"或者"sum"。
|
|
@ -0,0 +1,40 @@
|
|||
mindspore.ops.hinge_embedding_loss
|
||||
===================================
|
||||
|
||||
.. py:function:: mindspore.ops.HingeEmbeddingLoss(inputs, targets, margin=1.0, reduction="mean")
|
||||
|
||||
Hinge Embedding 损失函数。按输入元素计算输出。衡量输入张量x和标签y(包含1或-1)之间的损失值。通常被用来衡量两个输入之间的相似度。
|
||||
|
||||
mini-batch中的第n个样例的损失函数为:
|
||||
|
||||
.. math::
|
||||
l_n = \begin{cases}
|
||||
x_n, & \text{if}\; y_n = 1,\\
|
||||
\max \{0, \Delta - x_n\}, & \text{if}\; y_n = -1,
|
||||
\end{cases}
|
||||
|
||||
总损失值为:
|
||||
|
||||
.. math::
|
||||
\ell(x, y) = \begin{cases}
|
||||
\operatorname{mean}(L), & \text{if reduction} = \text{`mean`;}\\
|
||||
\operatorname{sum}(L), & \text{if reduction} = \text{`sum`.}
|
||||
\end{cases}
|
||||
|
||||
其中 :math:`L = \{l_1,\dots,l_N\}^\top`。
|
||||
|
||||
参数:
|
||||
- **inputs** (Tensor) - 预测值,公式中表示为 :math:`x`,shape为:math:`(*)`。`*` 代表着任意数量的维度。
|
||||
- **targets** (Tensor) - 标签值,公式中表示为 :math:`y`,和 `logits` 具有相同shape,包含1或-1。
|
||||
- **margin** (float) - Hinge Embedding Loss公式定义的阈值 :math:`margin`。公式中表示为:math:`\Delta`。默认值:1.0。
|
||||
- **reduction** (str) - 指定应用于输出结果的计算方式,'none'、'mean'、'sum',默认值:'mean'。
|
||||
|
||||
返回:
|
||||
Tensor或Tensor scalar,根据 :math:`reduction` 计算的loss。
|
||||
|
||||
异常:
|
||||
- **TypeError** - `inputs` 不是Tensor。
|
||||
- **TypeError** - `targets` 不是Tensor。
|
||||
- **TypeError** - `margin` 不是float。
|
||||
- **ValueError** - `inputs` 和 `targets` shape不一致。
|
||||
- **ValueError** - `reduction` 不是"none"、"mean"或者"sum"。
|
|
@ -237,6 +237,7 @@ Loss Function
|
|||
mindspore.nn.DiceLoss
|
||||
mindspore.nn.FocalLoss
|
||||
mindspore.nn.GaussianNLLLoss
|
||||
mindspore.nn.HingeEmbeddingLoss
|
||||
mindspore.nn.HuberLoss
|
||||
mindspore.nn.KLDivLoss
|
||||
mindspore.nn.L1Loss
|
||||
|
|
|
@ -56,6 +56,7 @@ Loss Functions
|
|||
mindspore.ops.binary_cross_entropy
|
||||
mindspore.ops.binary_cross_entropy_with_logits
|
||||
mindspore.ops.cross_entropy
|
||||
mindspore.ops.hinge_embedding_loss
|
||||
mindspore.ops.nll_loss
|
||||
mindspore.ops.smooth_l1_loss
|
||||
|
||||
|
|
|
@ -22,11 +22,12 @@ from __future__ import absolute_import
|
|||
|
||||
from mindspore.nn.loss.loss import LossBase, L1Loss, MSELoss, SmoothL1Loss, SoftMarginLoss, FocalLoss, \
|
||||
SoftmaxCrossEntropyWithLogits, BCELoss, MultiMarginLoss, CosineEmbeddingLoss, \
|
||||
SampledSoftmaxLoss, DiceLoss, BCEWithLogitsLoss, MultiClassDiceLoss, MultilabelMarginLoss, \
|
||||
RMSELoss, MAELoss, HuberLoss, CrossEntropyLoss, NLLLoss, KLDivLoss, MarginRankingLoss, GaussianNLLLoss
|
||||
SampledSoftmaxLoss, DiceLoss, BCEWithLogitsLoss, MultiClassDiceLoss, \
|
||||
RMSELoss, MAELoss, HuberLoss, CrossEntropyLoss, NLLLoss, KLDivLoss, MarginRankingLoss, GaussianNLLLoss, \
|
||||
HingeEmbeddingLoss, MultilabelMarginLoss
|
||||
|
||||
__all__ = ['LossBase', 'L1Loss', 'MSELoss', 'SmoothL1Loss', 'SoftMarginLoss', 'FocalLoss',
|
||||
'SoftmaxCrossEntropyWithLogits', 'BCELoss', 'BCEWithLogitsLoss', 'MultiMarginLoss',
|
||||
'CosineEmbeddingLoss', 'SampledSoftmaxLoss', 'DiceLoss', 'MultiClassDiceLoss', 'MultilabelMarginLoss',
|
||||
'RMSELoss', 'MAELoss', 'HuberLoss', 'CrossEntropyLoss', 'NLLLoss', 'KLDivLoss', 'MarginRankingLoss',
|
||||
'GaussianNLLLoss']
|
||||
'GaussianNLLLoss', 'HingeEmbeddingLoss']
|
||||
|
|
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
|||
import math
|
||||
import mindspore
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.ops as ops
|
||||
from mindspore import log
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
@ -2330,3 +2331,77 @@ class GaussianNLLLoss(LossBase):
|
|||
c = 0 if not self.full else 0.5 * math.log(2 * math.pi)
|
||||
loss = 0.5 * (logarithm + squared_loss / maxima) + c
|
||||
return self.get_loss(loss)
|
||||
|
||||
|
||||
class HingeEmbeddingLoss(LossBase):
|
||||
r"""
|
||||
Hinge Embedding Loss. Compute the output according to the input elements. Measures the loss given an input tensor x
|
||||
and a labels tensor y (containing 1 or -1).
|
||||
This is usually used for measuring the similarity between two inputs.
|
||||
|
||||
The loss function for :math:`n`-th sample in the mini-batch is
|
||||
|
||||
.. math::
|
||||
l_n = \begin{cases}
|
||||
x_n, & \text{if}\; y_n = 1,\\
|
||||
\max \{0, \Delta - x_n\}, & \text{if}\; y_n = -1,
|
||||
\end{cases}
|
||||
|
||||
and the total loss functions is
|
||||
|
||||
.. math::
|
||||
\ell(x, y) = \begin{cases}
|
||||
\operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
|
||||
\operatorname{sum}(L), & \text{if reduction} = \text{`sum'.}
|
||||
\end{cases}
|
||||
|
||||
where :math:`L = \{l_1,\dots,l_N\}^\top`.
|
||||
|
||||
Args:
|
||||
margin (float): Threshold defined by Hinge Embedding Loss :math:`margin`.
|
||||
Represented as :math:`\Delta` in the formula. Default: 1.0.
|
||||
reduction (string): Specify the computing method to be applied to the outputs: 'none', 'mean', or 'sum'.
|
||||
Default: 'mean'.
|
||||
|
||||
Inputs:
|
||||
- **logits** (Tensor) - Tensor of shape :math:`(*)` where :math:`*` means any number of dimensions.
|
||||
- **labels** (Tensor) - Same shape as the logits, contains -1 or 1.
|
||||
|
||||
Returns:
|
||||
Tensor or Tensor scalar, the computed loss depending on `reduction`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `logits` is not a Tensor.
|
||||
TypeError: If `labels` is not a Tensor.
|
||||
TypeError: If `margin` is not a float.
|
||||
ValueError: If `labels` does not have the same shape as `logits`.
|
||||
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examplse:
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import Tensor
|
||||
>>> import mindspore.nn as nn
|
||||
>>> import mindspore.common.dtype as mstype
|
||||
>>> arr1 = np.array([0.9, -1.2, 2, 0.8, 3.9, 2, 1, 0, -1]).reshape((3, 3))
|
||||
>>> arr2 = np.array([1, 1, -1, 1, -1, 1, -1, 1, 1]).reshape((3, 3))
|
||||
>>> logits = Tensor(arr1, mstype.float32)
|
||||
>>> labels = Tensor(arr2, mstype.float32)
|
||||
>>> loss = nn.HingeEmbeddingLoss(reduction='mean')
|
||||
>>> output = loss(logits, labels)
|
||||
>>> print(output)
|
||||
Tensor(shape=[], dtype=Float32, value= 1.6666667)
|
||||
"""
|
||||
|
||||
def __init__(self, margin=1.0, reduction='mean'):
|
||||
super(HingeEmbeddingLoss, self).__init__()
|
||||
validator.check_value_type('margin', margin, [float], self.cls_name)
|
||||
validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.cls_name)
|
||||
self.margin = margin
|
||||
self.reduction = reduction
|
||||
|
||||
def construct(self, logits, labels):
|
||||
loss = ops.hinge_embedding_loss(logits, labels, self.margin, self.reduction)
|
||||
return loss
|
||||
|
|
|
@ -339,6 +339,7 @@ from .nn_func import (
|
|||
multi_label_margin_loss,
|
||||
elu,
|
||||
gelu,
|
||||
hinge_embedding_loss,
|
||||
)
|
||||
from .linalg_func import (
|
||||
svd,
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
from __future__ import absolute_import
|
||||
from math import pi
|
||||
|
||||
import mindspore.ops as ops
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations import nn_ops as NN_OPS
|
||||
|
@ -2459,6 +2460,90 @@ def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reducti
|
|||
return (loss, log_alpha)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_hinge_embedding_loss(shape, shape2, prim_name):
|
||||
if shape2 != shape:
|
||||
raise ValueError(f"For '{prim_name}' the input tensor and the labels must have the same shape.")
|
||||
|
||||
|
||||
def hinge_embedding_loss(inputs, targets, margin=1.0, reduction='mean'):
|
||||
r"""
|
||||
Hinge Embedding Loss. Compute the output according to the input elements. Measures the loss given an input tensor x
|
||||
and a labels tensor y (containing 1 or -1).
|
||||
This is usually used for measuring the similarity between two inputs.
|
||||
|
||||
The loss function for :math:`n`-th sample in the mini-batch is
|
||||
|
||||
.. math::
|
||||
l_n = \begin{cases}
|
||||
x_n, & \text{if}\; y_n = 1,\\
|
||||
\max \{0, \Delta - x_n\}, & \text{if}\; y_n = -1,
|
||||
\end{cases}
|
||||
|
||||
and the total loss functions is
|
||||
|
||||
.. math::
|
||||
\ell(x, y) = \begin{cases}
|
||||
\operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
|
||||
\operatorname{sum}(L), & \text{if reduction} = \text{`sum'.}
|
||||
\end{cases}
|
||||
|
||||
where :math:`L = \{l_1,\dots,l_N\}^\top`.
|
||||
|
||||
Args:
|
||||
inputs (Tensor) - Tensor of shape :math:`(*)` where :math:`*` means any number of dimensions.
|
||||
targets (Tensor) - Same shape as the logits, contains -1 or 1.
|
||||
margin (float): Threshold defined by Hinge Embedding Loss :math:`margin`.
|
||||
Represented as :math:`\Delta` in the formula. Default: 1.0.
|
||||
reduction (string): Specify the computing method to be applied to the outputs: 'none', 'mean', or 'sum'.
|
||||
Default: 'mean'.
|
||||
|
||||
Returns:
|
||||
Tensor or Tensor scalar, the computed loss depending on `reduction`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `inputs` is not a Tensor.
|
||||
TypeError: If `targets` is not a Tensor.
|
||||
TypeError: If `margin` is not a float.
|
||||
ValueError: If `targets` does not have the same shape as `inputs`.
|
||||
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examplse:
|
||||
>>> import numpy as np
|
||||
>>> import mindspore.common.dtype as mstype
|
||||
>>> import mindspore.ops as ops
|
||||
>>> from mindspore import Tensor
|
||||
>>> arr1 = np.array([0.9, -1.2, 2, 0.8, 3.9, 2, 1, 0, -1]).reshape((3, 3))
|
||||
>>> arr2 = np.array([1, 1, -1, 1, -1, 1, -1, 1, 1]).reshape((3, 3))
|
||||
>>> logits = Tensor(arr1, mstype.float32)
|
||||
>>> labels = Tensor(arr2, mstype.float32)
|
||||
>>> loss = ops.hinge_embedding_loss(logits, targets, margin=1.0, reduction='mean')
|
||||
>>> print(loss)
|
||||
Tensor(shape=[], dtype=Float32, value= 1.6666667)
|
||||
"""
|
||||
_shape = inputs.shape
|
||||
_dtype = inputs.dtype
|
||||
_t_shape = targets.shape
|
||||
_check_hinge_embedding_loss(_shape, _t_shape, 'HingeEmbeddingLoss')
|
||||
min_val = Tensor(0, _dtype)
|
||||
pos_index = targets > 0
|
||||
neg_index = targets < 0
|
||||
pos = pos_index * inputs
|
||||
neg = neg_index * inputs
|
||||
margin_matrix = margin * neg_index
|
||||
neg = margin_matrix - neg
|
||||
neg = ops.clip_by_value(neg, min_val)
|
||||
loss = pos + neg
|
||||
if reduction == 'mean':
|
||||
loss = loss.mean()
|
||||
elif reduction == 'sum':
|
||||
loss = loss.sum()
|
||||
return loss
|
||||
|
||||
|
||||
def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
|
||||
r"""
|
||||
Performs greedy decoding on the logits given in inputs.
|
||||
|
@ -3466,5 +3551,6 @@ __all__ = [
|
|||
'multi_label_margin_loss',
|
||||
'elu',
|
||||
'gelu',
|
||||
'hinge_embedding_loss'
|
||||
]
|
||||
__all__.sort()
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, reduction='mean'):
|
||||
super(Net, self).__init__()
|
||||
self.loss = nn.HingeEmbeddingLoss(margin=1.0, reduction=reduction)
|
||||
|
||||
def construct(self, x, label):
|
||||
loss = self.loss(x, label)
|
||||
return loss
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||
@pytest.mark.parametrize('reduction', ['mean', 'sum', 'none'])
|
||||
def test_hinge_embedding_loss(mode, reduction):
|
||||
"""
|
||||
Feature: HingeEmbeddingLoss with margin=1.0
|
||||
Description: Verify the result of HingeEmbeddingLoss
|
||||
Expectation: success
|
||||
"""
|
||||
context.set_context(mode=mode)
|
||||
net = Net(reduction=reduction)
|
||||
arr1 = np.array([0.9, -1.2, 2, 0.8, 3.9, 2, 1, 0, -1]).reshape((3, 3))
|
||||
arr2 = np.array([1, 1, -1, 1, -1, 1, -1, 1, 1]).reshape((3, 3))
|
||||
a = Tensor(arr1, mstype.float32)
|
||||
b = Tensor(arr2, mstype.float32)
|
||||
output = net(a, b)
|
||||
|
||||
if reduction == 'mean':
|
||||
expected = np.array(1 / 6, np.float32)
|
||||
elif reduction == 'sum':
|
||||
expected = np.array(1.5, np.float32)
|
||||
else:
|
||||
expected = np.array([[0.9000, -1.2000, 0.0000],
|
||||
[0.8000, 0.0000, 2.0000],
|
||||
[0.0000, 0.0000, -1.0000]], np.float32)
|
||||
assert np.allclose(output.asnumpy(), expected)
|
|
@ -0,0 +1,51 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, reduction='mean'):
|
||||
super(Net, self).__init__()
|
||||
self.reduction = reduction
|
||||
|
||||
def construct(self, x, label):
|
||||
loss = ops.hinge_embedding_loss(x, label, reduction=self.reduction)
|
||||
return loss
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||
@pytest.mark.parametrize('reduction', ['mean', 'sum', 'none'])
|
||||
def test_hinge_embedding_loss(mode, reduction):
|
||||
"""
|
||||
Feature: HingeEmbeddingLoss with margin=1.0
|
||||
Description: Verify the result of HingeEmbeddingLoss
|
||||
Expectation: success
|
||||
"""
|
||||
context.set_context(mode=mode)
|
||||
net = Net(reduction=reduction)
|
||||
arr1 = np.array([0.9, -1.2, 2, 0.8, 3.9, 2, 1, 0, -1]).reshape((3, 3))
|
||||
arr2 = np.array([1, 1, -1, 1, -1, 1, -1, 1, 1]).reshape((3, 3))
|
||||
a = Tensor(arr1, mstype.float32)
|
||||
b = Tensor(arr2, mstype.float32)
|
||||
output = net(a, b)
|
||||
|
||||
if reduction == 'mean':
|
||||
expected = np.array(1 / 6, np.float32)
|
||||
elif reduction == 'sum':
|
||||
expected = np.array(1.5, np.float32)
|
||||
else:
|
||||
expected = np.array([[0.9000, -1.2000, 0.0000],
|
||||
[0.8000, 0.0000, 2.0000],
|
||||
[0.0000, 0.0000, -1.0000]], np.float32)
|
||||
assert np.allclose(output.asnumpy(), expected)
|
|
@ -0,0 +1,34 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, reduction='mean'):
|
||||
super(Net, self).__init__()
|
||||
self.reduction = reduction
|
||||
self.loss_func = nn.HingeEmbeddingLoss(reduction=self.reduction)
|
||||
|
||||
def construct(self, x, label):
|
||||
loss = self.loss_func(x, label)
|
||||
return loss
|
||||
|
||||
|
||||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||
def test_hinge_embedding_loss_abnormal(mode):
|
||||
"""
|
||||
Feature: HingeEmbeddingLoss
|
||||
Description: Verify abnormal inputs of HingeEmbeddingLoss
|
||||
Expectation: raise ValueError
|
||||
"""
|
||||
context.set_context(mode=mode)
|
||||
net = Net(reduction='mean')
|
||||
arr1 = np.array([0.9, -1.2, 2, 0.8, 3.9, 2, 1, 0, -1]).reshape((3, 3))
|
||||
arr2 = np.array([1, 1, -1, 1]).reshape((2, 2))
|
||||
a = Tensor(arr1, mstype.float32)
|
||||
b = Tensor(arr2, mstype.float32)
|
||||
with pytest.raises(ValueError):
|
||||
net(a, b)
|
|
@ -0,0 +1,34 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, reduction='mean'):
|
||||
super(Net, self).__init__()
|
||||
self.reduction = reduction
|
||||
|
||||
def construct(self, x, label):
|
||||
loss = ops.hinge_embedding_loss(x, label, reduction=self.reduction)
|
||||
return loss
|
||||
|
||||
|
||||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||
def test_hinge_embedding_loss_abnormal(mode):
|
||||
"""
|
||||
Feature: HingeEmbeddingLoss
|
||||
Description: Verify abnormal inputs of HingeEmbeddingLoss
|
||||
Expectation: raise ValueError
|
||||
"""
|
||||
context.set_context(mode=mode)
|
||||
net = Net(reduction='mean')
|
||||
arr1 = np.array([0.9, -1.2, 2, 0.8, 3.9, 2, 1, 0, -1]).reshape((3, 3))
|
||||
arr2 = np.array([1, 1, -1, 1]).reshape((2, 2))
|
||||
a = Tensor(arr1, mstype.float32)
|
||||
b = Tensor(arr2, mstype.float32)
|
||||
with pytest.raises(ValueError):
|
||||
net(a, b)
|
Loading…
Reference in New Issue