!47654 add api: ops.triplet_margin_loss and nn.TripletMarginLoss

Merge pull request !47654 from GuoZhibin/Add_TripletMarginLoss
This commit is contained in:
i-robot 2023-01-13 07:39:27 +00:00 committed by Gitee
commit faf732e48b
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
12 changed files with 430 additions and 25 deletions

View File

@ -256,6 +256,7 @@ Dropout层
mindspore.nn.SmoothL1Loss
mindspore.nn.SoftMarginLoss
mindspore.nn.SoftmaxCrossEntropyWithLogits
mindspore.nn.TripletMarginLoss
优化器
-------

View File

@ -75,6 +75,7 @@ mindspore.ops
mindspore.ops.mse_loss
mindspore.ops.nll_loss
mindspore.ops.smooth_l1_loss
mindspore.ops.triplet_margin_loss
激活函数
^^^^^^^^^^

View File

@ -0,0 +1,49 @@
mindspore.nn.TripletMarginLoss
===============================
.. py:class:: class TripletMarginLoss(p=2, swap=False, eps=1e-06, reduction='mean')
执行三元组损失函数的操作。
创建一个标准用于计算输入Tensor :math:`x`:math:`positive`:math:`negative` 与大于 :math:`0``margin` 之间的三元组损失值。
可以用来测量样本之间的相似度。一个三元组包含 `a``p``n` (即分别代表 `x``positive``negative` )。
所有输入Tensor的shape都应该为 :math:`(N, D)`
距离交换在V. Balntas、E. Riba等人在论文 `Learning local feature descriptors with triplets and shallow convolutional neural networks <http://158.109.8.37/files/BRP2016.pdf>`_ 中有详细的阐述。
对于每个小批量样本,损失值为:
.. math::
L(a, p, n) = \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\}
其中
.. math::
d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p
参数:
- **p** (int可选) - 成对距离的范数。默认值2。
- **swap** (bool可选) - 距离交换。默认值False。
- **eps** (float可选) - 防止除数为 0。默认值1e-06。
- **reduction** (str可选) - 指定要应用于输出的缩减方式,取值为"mean"、"sum"或"none"。默认值:"mean"。
输入:
- **x** (Tensor) - 从训练集随机选取的样本。数据类型为BasicType。
- **positive** (Tensor) - 与 `x` 为同一类的样本数据类型与shape与 `x` 一致。
- **negative** (Tensor) - 与 `x` 为异类的样本数据类型与shape与 `x` 一致。
- **margin** (Tensor) - 用于拉进 `x``positive` 之间的距离,拉远 `x``negative` 之间的距离。
输出:
Tensor。如果 `reduction` 为"none"其shape为 :math:`(N)`。否则将返回Scalar。
异常:
- **TypeError** - `x``positive``negative` 或者 `margin` 不是Tensor。
- **TypeError** - `x``positive` 或者 `negative` 的数据类型不一致。
- **TypeError** - `margin` 的数据类型不是float32。
- **TypeError** - `p` 的数据类型不是int。
- **TypeError** - `eps` 的数据类型不是float。
- **TypeError** - `swap` 的数据类型不是bool。
- **ValueError** - `x``positive``negative` 的维度同时小于等于1。
- **ValueError** - `x``positive``negative` 的维度大于等于8。
- **ValueError** - `margin` 的shape长度不为0。
- **ValueError** - `x``positive``negative` 三者之间的shape无法广播。
- **ValueError** - `reduction` 不为"mean"、"sum"或"none"。

View File

@ -0,0 +1,32 @@
mindspore.ops.triplet_margin_loss
==================================
.. py:function:: mindspore.ops.triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, reduction='mean')
三元组损失函数。
详情请查看 :class:`mindspore.nn.TripletMarginLoss`
参数:
- **anchor** (Tensor) - 从训练集随机选取的样本。数据类型为BasicType。
- **positive** (Tensor) - 与 `anchor` 为同一类的样本数据类型与shape与 `anchor` 一致。
- **negative** (Tensor) - 与 `anchor` 为异类的样本数据类型与shape与 `anchor` 一致。
- **margin** (float可选) - 用于拉进 `anchor``positive` 之间的距离,拉远 `anchor``negative` 之间的距离。默认值1.0。
- **p** (int可选) - 成对距离的范数。默认值2。
- **eps** (float可选) - 防止除数为 0。默认值1e-06。
- **swap** (bool可选) - 距离交换。默认值False。
- **reduction** (str可选) - 指定要应用于输出的缩减方式,取值为"mean"、"sum"或"none"。默认值:"mean"。
返回:
Tensor。如果 `reduction` 为"none"其shape为 :math:`(N)`。否则将返回Scalar。
异常:
- **TypeError** - `anchor``positive` 或者 `negative` 不是Tensor。
- **TypeError** - `anchor``positive` 或者 `negative` 的数据类型不一致。
- **TypeError** - `margin` 的数据类型不是float。
- **TypeError** - `p` 的数据类型不是int。
- **TypeError** - `eps` 的数据类型不是float。
- **TypeError** - `swap` 的数据类型不是bool。
- **ValueError** - `anchor``positive``negative` 的维度同时小于等于1。
- **ValueError** - `anchor``positive``negative` 的维度大于等于8。
- **ValueError** - `anchor``positive``negative` 三者之间的shape无法广播。
- **ValueError** - `reduction` 不为"mean"、"sum"或"none"。

View File

@ -256,6 +256,7 @@ Loss Function
mindspore.nn.SmoothL1Loss
mindspore.nn.SoftMarginLoss
mindspore.nn.SoftmaxCrossEntropyWithLogits
mindspore.nn.TripletMarginLoss
Optimizer
---------

View File

@ -76,6 +76,7 @@ Loss Functions
mindspore.ops.mse_loss
mindspore.ops.nll_loss
mindspore.ops.smooth_l1_loss
mindspore.ops.triplet_margin_loss
Activation Functions
^^^^^^^^^^^^^^^^^^^^

View File

@ -24,7 +24,7 @@ from mindspore.nn.loss.loss import LossBase, L1Loss, CTCLoss, MSELoss, SmoothL1L
SoftmaxCrossEntropyWithLogits, BCELoss, MultiMarginLoss, CosineEmbeddingLoss, \
SampledSoftmaxLoss, PoissonNLLLoss, MultiLabelSoftMarginLoss, DiceLoss, BCEWithLogitsLoss, MultiClassDiceLoss, \
RMSELoss, MAELoss, HuberLoss, CrossEntropyLoss, NLLLoss, KLDivLoss, MarginRankingLoss, GaussianNLLLoss, \
HingeEmbeddingLoss, MultilabelMarginLoss
HingeEmbeddingLoss, MultilabelMarginLoss, TripletMarginLoss
__all__ = ['LossBase', 'L1Loss', 'CTCLoss', 'MSELoss', 'SmoothL1Loss', 'SoftMarginLoss', 'FocalLoss',
@ -32,4 +32,4 @@ __all__ = ['LossBase', 'L1Loss', 'CTCLoss', 'MSELoss', 'SmoothL1Loss', 'SoftMarg
'CosineEmbeddingLoss', 'SampledSoftmaxLoss', 'PoissonNLLLoss',
'MultiLabelSoftMarginLoss', 'DiceLoss', 'MultiClassDiceLoss', 'MultilabelMarginLoss',
'RMSELoss', 'MAELoss', 'HuberLoss', 'CrossEntropyLoss', 'NLLLoss', 'KLDivLoss', 'MarginRankingLoss',
'GaussianNLLLoss', 'HingeEmbeddingLoss']
'GaussianNLLLoss', 'HingeEmbeddingLoss', 'TripletMarginLoss']

View File

@ -26,7 +26,6 @@ from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops.operations.nn_ops import MultiMarginLoss as MultiMarginLossOp
from mindspore.ops.operations.nn_ops import MultilabelMarginLoss as MultilabelMarginLossOp
from mindspore.ops.operations.nn_ops import TripletMarginLoss as TripletMarginLossOp
from mindspore.ops import functional as F
from mindspore import nn
from mindspore.ops.primitive import constexpr
@ -1880,15 +1879,16 @@ class TripletMarginLoss(LossBase):
TripletMarginLoss operation.
Creates a criterion that measures the triplet loss given an input
tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`.
This is used for measuring a relative similarity between samples. A triplet
is composed by `a`, `p` and `n` (i.e., `anchor`, `positive examples` and `negative
examples` respectively). The shapes of all input tensors should be
tensors :math:`x`, :math:`positive`, :math:`negative` and a :math:`margin` with a value greater than :math:`0`.
This is used for measuring a relative similarity between samples.
A triplet is composed by `a`, `p` and `n` (i.e., `x`, `positive` and `negative` respectively).
The shapes of all input tensors should be
:math:`(N, D)`.
The distance swap is described in detail in the paper `Learning shallow
convolutional feature descriptors with triplet losses` by
V. Balntas, E. Riba et al.
The distance swap is described in detail in the paper
`Learning local feature descriptors with triplets and shallow convolutional neural
networks <http://158.109.8.37/files/BRP2016.pdf>`_
by V. Balntas, E. Riba et al.
The loss function for each sample in the mini-batch is:
@ -1901,26 +1901,25 @@ class TripletMarginLoss(LossBase):
d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p
Args:
p (int): The norm degree for pairwise distance. Default: 2.
eps (float): Default: 1e-06.
swap (bool): The distance swap is described in detail in the paper
`Learning shallow convolutional feature descriptors with triplet losses` by
V. Balntas, E. Riba et al. Default: "False".
reduction (str): Apply specific reduction method to the output: 'none', 'mean', 'sum'. Default: "mean".
p (int, optional): The norm degree for pairwise distance. Default: 2.
eps (float, optional): Add small value to avoid division by zero. Default: 1e-06.
swap (bool, optional): The distance swap change the negative distance to the distance between positive
sample and negative sample. Default: "False".
reduction (str, optional): Apply specific reduction method to the output: 'none', 'mean', 'sum'.
Default: "mean".
Inputs:
- **x** (Tensor) - A sample randomly selected from the training set. Data type must be BasicType.
- **positive** (Tensor) - A sample belonging to the same category as x, with the same type and shape as `x`.
- **negative** (Tensor) - A sample belonging to the different class from x, with the same type and shape as `x`.
- **positive** (Tensor) - A sample belonging to the same category as `x`, with the same type and shape as `x`.
- **negative** (Tensor) - A sample belonging to the different class from `x`, with the same type and shape
as `x`.
- **margin** (Tensor) - Make a margin between the positive pair and the negative pair.
Outputs:
Union[Tensor, Scalar], if `reduction` is "none", its shape is :math:`(N)`.
Otherwise, a scalar value will be returned.
Tensor. If `reduction` is "none", its shape is :math:`(N)`. Otherwise, a scalar value will be returned.
Raises:
TypeError: If `x` or `positive` or 'negative' or 'margin' is not a Tensor.
TypeError: If dtype of `x` or `positive` or `negative` is not BasicType.
TypeError: If dtype of `x`, `positive` and `negative` is not the same.
TypeError: If `margin` is not float32.
TypeError: If `p` is not an int.
@ -1933,7 +1932,7 @@ class TripletMarginLoss(LossBase):
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU`
``GPU``
Examples:
>>> loss = nn.TripletMarginLoss()
@ -1946,12 +1945,16 @@ class TripletMarginLoss(LossBase):
0.8881968
"""
def __init__(self, p=2, swap=False, eps=1e-6, reduction='mean'):
def __init__(self, p=2, swap=False, eps=1e-06, reduction='mean'):
super(TripletMarginLoss, self).__init__()
self.triplet_margin_loss = TripletMarginLossOp(p=p, swap=swap, eps=eps, reduction=reduction)
self.p = p
self.swap = swap
self.eps = eps
self.reduction = reduction
def construct(self, x, positive, negative, margin):
return self.triplet_margin_loss(x, positive, negative, margin)
return F.triplet_margin_loss(x, positive, negative, margin=margin, p=self.p,
eps=self.eps, swap=self.swap, reduction=self.reduction)
@constexpr

View File

@ -465,6 +465,7 @@ from .nn_func import (
lp_pool1d,
lp_pool2d,
mse_loss,
triplet_margin_loss,
msort
)
from .linalg_func import (

View File

@ -35,6 +35,7 @@ from mindspore.ops.operations.nn_ops import MaxUnpool2D, MaxUnpool3D
from mindspore.ops.operations.nn_ops import FractionalMaxPoolWithFixedKsize, FractionalMaxPool3DWithFixedKsize
from mindspore.ops.operations.nn_ops import PadV3
from mindspore.ops.operations.nn_ops import ChannelShuffle
from mindspore.ops.operations.nn_ops import TripletMarginLoss
slice_ = P.Slice()
fast_gelu_ = P.FastGeLU()
@ -5611,6 +5612,58 @@ def msort(x):
return ops.Sort(axis=0)(x)[0]
def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, reduction='mean'):
"""
TripletMarginLoss operation.
See :class:`mindspore.nn.TripletMarginLoss` for details.
Args:
anchor (Tensor): A sample randomly selected from the training set. Data type must be BasicType.
positive (Tensor): A sample belonging to the same category as `anchor`, with the same type and shape
as `anchor`.
negative (Tensor): A sample belonging to the different class from `anchor`, with the same type and shape
as `anchor`.
margin (float, optional): Make a margin between the positive pair and the negative pair. Default: 1.0.
p (int, optional): The norm degree for pairwise distance. Default: 2.
eps (float, optional): Add small value to avoid division by zero. Default: 1e-06.
swap (bool, optional): The distance swap change the negative distance to the distance between positive
sample and negative sample. Default: "False".
reduction (str, optional): Apply specific reduction method to the output: 'none', 'mean', 'sum'.
Default: "mean".
Returns:
Tensor. If `reduction` is "none", its shape is :math:`(N)`. Otherwise, a scalar value will be returned.
Raises:
TypeError: If `anchor` or `positive` or 'negative' is not a Tensor.
TypeError: If dtype of `anchor`, `positive` and `negative` is not the same.
TypeError: If `margin` is not a float.
TypeError: If `p` is not an int.
TypeError: If `eps` is not a float.
TypeError: If `swap` is not a bool.
ValueError: If dimensions of input `anchor`, `positive` and `negative` are less than or equal to 1 at the
same time.
ValueError: If the dimension of input `anchor` or `positive` or `negative` is bigger than or equal to 8.
ValueError: If shape of `anchor`, `positive` and `negative` cannot broadcast.
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
Supported Platforms:
``GPU``
Examples:
>>> anchor = Tensor(np.array([[0.3, 0.7], [0.5, 0.5]]), mindspore.float32)
>>> positive = Tensor(np.array([[0.4, 0.6], [0.4, 0.6]]), mindspore.float32)
>>> negative = Tensor(np.array([[0.2, 0.9], [0.3, 0.7]]), mindspore.float32)
>>> output = ops.triplet_margin_loss(anchor, positive, negative)
>>> print(output)
0.8881968
"""
if not isinstance(margin, Tensor):
margin = Tensor(margin, mstype.float32)
triplet_margin_loss_op = _get_cache_prim(TripletMarginLoss)(p=p, eps=eps, swap=swap, reduction=reduction)
return triplet_margin_loss_op(anchor, positive, negative, margin)
__all__ = [
'adaptive_avg_pool1d',
'adaptive_avg_pool2d',
@ -5695,5 +5748,6 @@ __all__ = [
'max_unpool3d',
'mse_loss',
'msort',
'triplet_margin_loss',
]
__all__.sort()

View File

@ -0,0 +1,123 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import torch
import numpy as np
import pytest
import mindspore as ms
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_triplet_margin_loss_float64(mode):
"""
Feature: Input type of float64
Description: Input type of [float64, float64, float64].
Expectation: success.
"""
context.set_context(mode=mode)
data_type = np.float64
anchor_array = np.array([[1.3, 20.5, 5.6],
[3.5, 4.8, 7.2],
[0.2, 0.01, 1],
[4, 4.1, 20]]).astype(data_type)
positive_array = np.array([[2., 10., 1.],
[6., 7., 10.],
[13., 4., 1.],
[0.33, -4, -1.5]]).astype(data_type)
negative_array = np.array([[2., 21., 6.],
[68., 9., 10.],
[131., 25., 16.],
[0.31, -0.14, -16.]]).astype(data_type)
margin = np.float32(2.0)
p = 0
swap = True
reduction = "none"
eps = 1e-5
anchor = Tensor(anchor_array)
positive = Tensor(positive_array)
negative = Tensor(negative_array)
ms_margin = Tensor(margin)
triplet_margin_loss = nn.TripletMarginLoss(p=p, eps=eps, swap=swap, reduction=reduction)
output_ms = triplet_margin_loss(anchor, positive, negative, ms_margin)
torch_anchor = torch.tensor(anchor_array)
torch_positive = torch.tensor(positive_array)
torch_negative = torch.tensor(negative_array)
expect = torch.nn.functional.triplet_margin_loss(torch_anchor, torch_positive,
torch_negative, margin=margin,
p=p, eps=eps, swap=swap,
reduction=reduction)
assert np.allclose(output_ms.asnumpy(),
expect.numpy(),
rtol=1e-4,
atol=1e-4,
equal_nan=False)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_triplet_margin_loss_float32(mode):
"""
Feature: Input type of float32
Description: Input type of [float32, float32, float32].
Expectation: success.
"""
context.set_context(mode=mode)
data_type = np.float32
anchor_array = np.array([[1.3, 20.5, 5.6],
[3.5, 4.8, 7.2],
[0.2, 0.01, 1],
[4, 4.1, 20]]).astype(data_type)
positive_array = np.array([[2., 10., 1.],
[6., 7., 10.],
[13., 4., 1.],
[0.33, -4, -1.5]]).astype(data_type)
negative_array = np.array([[2., 21., 6.],
[68., 9., 10.],
[131., 25., 16.],
[0.31, -0.14, -16.]]).astype(data_type)
margin = np.float32(2.0)
p = 1
swap = False
reduction = "none"
eps = 1e-6
anchor = Tensor(anchor_array)
positive = Tensor(positive_array)
negative = Tensor(negative_array)
ms_margin = Tensor(margin)
triplet_margin_loss = nn.TripletMarginLoss(p=p, eps=eps, swap=swap, reduction=reduction)
output_ms = triplet_margin_loss(anchor, positive, negative, ms_margin)
torch_anchor = torch.tensor(anchor_array)
torch_positive = torch.tensor(positive_array)
torch_negative = torch.tensor(negative_array)
expect = torch.nn.functional.triplet_margin_loss(torch_anchor, torch_positive,
torch_negative, margin=margin,
p=p, eps=eps, swap=swap,
reduction=reduction)
assert np.allclose(output_ms.asnumpy(),
expect.numpy(),
rtol=1e-4,
atol=1e-4,
equal_nan=False)

View File

@ -0,0 +1,139 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import torch
import numpy as np
import pytest
import mindspore as ms
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.ops as ops
import mindspore.common.dtype as mstype
class NetTripletMarginLoss(nn.Cell):
def __init__(self, margin=Tensor(1.0, mstype.float32), p=2, swap=False, eps=1e-6, reduction="mean"):
super(NetTripletMarginLoss, self).__init__()
self.margin = margin
self.p = p
self.swap = swap
self.eps = eps
self.reduction = reduction
def construct(self, anchor, positive, negative):
return ops.triplet_margin_loss(anchor, positive, negative, margin=self.margin, p=self.p,
eps=self.eps, swap=self.swap, reduction=self.reduction)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_triplet_margin_loss_float64(mode):
"""
Feature: Input type of float64
Description: Input type of [float64, float64, float64].
Expectation: success.
"""
context.set_context(mode=mode)
data_type = np.float64
anchor_array = np.array([[1.3, 20.5, 5.6],
[3.5, 4.8, 7.2],
[0.2, 0.01, 1],
[4, 4.1, 20]]).astype(data_type)
positive_array = np.array([[2., 10., 1.],
[6., 7., 10.],
[13., 4., 1.],
[0.33, -4, -1.5]]).astype(data_type)
negative_array = np.array([[2., 21., 6.],
[68., 9., 10.],
[131., 25., 16.],
[0.31, -0.14, -16.]]).astype(data_type)
margin = np.float32(2.0)
p = 0
swap = True
reduction = "none"
eps = 1e-5
anchor = Tensor(anchor_array)
positive = Tensor(positive_array)
negative = Tensor(negative_array)
triplet_margin_loss = NetTripletMarginLoss(margin=margin, p=p, eps=eps,
swap=swap, reduction=reduction)
output_ms = triplet_margin_loss(anchor, positive, negative)
torch_anchor = torch.tensor(anchor_array)
torch_positive = torch.tensor(positive_array)
torch_negative = torch.tensor(negative_array)
expect = torch.nn.functional.triplet_margin_loss(torch_anchor, torch_positive,
torch_negative, margin=margin,
p=p, eps=eps, swap=swap,
reduction=reduction)
assert np.allclose(output_ms.asnumpy(),
expect.numpy(),
rtol=1e-4,
atol=1e-4,
equal_nan=False)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_triplet_margin_loss_float32(mode):
"""
Feature: Input type of float32
Description: Input type of [float32, float32, float32].
Expectation: success.
"""
context.set_context(mode=mode)
data_type = np.float32
anchor_array = np.array([[1.3, 20.5, 5.6],
[3.5, 4.8, 7.2],
[0.2, 0.01, 1],
[4, 4.1, 20]]).astype(data_type)
positive_array = np.array([[2., 10., 1.],
[6., 7., 10.],
[13., 4., 1.],
[0.33, -4, -1.5]]).astype(data_type)
negative_array = np.array([[2., 21., 6.],
[68., 9., 10.],
[131., 25., 16.],
[0.31, -0.14, -16.]]).astype(data_type)
margin = np.float32(2.0)
p = 1
swap = False
reduction = "none"
eps = 1e-6
anchor = Tensor(anchor_array)
positive = Tensor(positive_array)
negative = Tensor(negative_array)
triplet_margin_loss = NetTripletMarginLoss(margin=margin, p=p, eps=eps,
swap=swap, reduction=reduction)
output_ms = triplet_margin_loss(anchor, positive, negative)
torch_anchor = torch.tensor(anchor_array)
torch_positive = torch.tensor(positive_array)
torch_negative = torch.tensor(negative_array)
expect = torch.nn.functional.triplet_margin_loss(torch_anchor, torch_positive,
torch_negative, margin=margin,
p=p, eps=eps, swap=swap,
reduction=reduction)
assert np.allclose(output_ms.asnumpy(),
expect.numpy(),
rtol=1e-4,
atol=1e-4,
equal_nan=False)