diff --git a/docs/api/api_python/mindspore.nn.rst b/docs/api/api_python/mindspore.nn.rst index ca82d798196..e876adad620 100644 --- a/docs/api/api_python/mindspore.nn.rst +++ b/docs/api/api_python/mindspore.nn.rst @@ -256,6 +256,7 @@ Dropout层 mindspore.nn.SmoothL1Loss mindspore.nn.SoftMarginLoss mindspore.nn.SoftmaxCrossEntropyWithLogits + mindspore.nn.TripletMarginLoss 优化器 ------- diff --git a/docs/api/api_python/mindspore.ops.rst b/docs/api/api_python/mindspore.ops.rst index 9a915b1309c..57f70efa202 100644 --- a/docs/api/api_python/mindspore.ops.rst +++ b/docs/api/api_python/mindspore.ops.rst @@ -75,6 +75,7 @@ mindspore.ops mindspore.ops.mse_loss mindspore.ops.nll_loss mindspore.ops.smooth_l1_loss + mindspore.ops.triplet_margin_loss 激活函数 ^^^^^^^^^^ diff --git a/docs/api/api_python/nn/mindspore.nn.TripletMarginLoss.rst b/docs/api/api_python/nn/mindspore.nn.TripletMarginLoss.rst new file mode 100644 index 00000000000..91d15dde426 --- /dev/null +++ b/docs/api/api_python/nn/mindspore.nn.TripletMarginLoss.rst @@ -0,0 +1,49 @@ +mindspore.nn.TripletMarginLoss +=============================== + +.. py:class:: class TripletMarginLoss(p=2, swap=False, eps=1e-06, reduction='mean') + + 执行三元组损失函数的操作。 + + 创建一个标准,用于计算输入Tensor :math:`x` 、 :math:`positive` 和 :math:`negative` 与大于 :math:`0` 的 `margin` 之间的三元组损失值。 + 可以用来测量样本之间的相似度。一个三元组包含 `a` 、 `p` 和 `n` (即分别代表 `x` 、 `positive` 和 `negative` )。 + 所有输入Tensor的shape都应该为 :math:`(N, D)` 。 + 距离交换在V. Balntas、E. Riba等人在论文 `Learning local feature descriptors with triplets and shallow convolutional neural networks `_ 中有详细的阐述。 + + 对于每个小批量样本,损失值为: + + .. math:: + L(a, p, n) = \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\} + + 其中 + + .. math:: + d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p + + 参数: + - **p** (int,可选) - 成对距离的范数。默认值:2。 + - **swap** (bool,可选) - 距离交换。默认值:False。 + - **eps** (float,可选) - 防止除数为 0。默认值:1e-06。 + - **reduction** (str,可选) - 指定要应用于输出的缩减方式,取值为"mean"、"sum"或"none"。默认值:"mean"。 + + 输入: + - **x** (Tensor) - 从训练集随机选取的样本。数据类型为BasicType。 + - **positive** (Tensor) - 与 `x` 为同一类的样本,数据类型与shape与 `x` 一致。 + - **negative** (Tensor) - 与 `x` 为异类的样本,数据类型与shape与 `x` 一致。 + - **margin** (Tensor) - 用于拉进 `x` 和 `positive` 之间的距离,拉远 `x` 和 `negative` 之间的距离。 + + 输出: + Tensor。如果 `reduction` 为"none",其shape为 :math:`(N)`。否则,将返回Scalar。 + + 异常: + - **TypeError** - `x` 、 `positive` 、 `negative` 或者 `margin` 不是Tensor。 + - **TypeError** - `x` 、 `positive` 或者 `negative` 的数据类型不一致。 + - **TypeError** - `margin` 的数据类型不是float32。 + - **TypeError** - `p` 的数据类型不是int。 + - **TypeError** - `eps` 的数据类型不是float。 + - **TypeError** - `swap` 的数据类型不是bool。 + - **ValueError** - `x` 、 `positive` 和 `negative` 的维度同时小于等于1。 + - **ValueError** - `x` 、 `positive` 或 `negative` 的维度大于等于8。 + - **ValueError** - `margin` 的shape长度不为0。 + - **ValueError** - `x` 、 `positive` 和 `negative` 三者之间的shape无法广播。 + - **ValueError** - `reduction` 不为"mean"、"sum"或"none"。 \ No newline at end of file diff --git a/docs/api/api_python/ops/mindspore.ops.func_triplet_margin_loss.rst b/docs/api/api_python/ops/mindspore.ops.func_triplet_margin_loss.rst new file mode 100644 index 00000000000..445489b895e --- /dev/null +++ b/docs/api/api_python/ops/mindspore.ops.func_triplet_margin_loss.rst @@ -0,0 +1,32 @@ +mindspore.ops.triplet_margin_loss +================================== + +.. py:function:: mindspore.ops.triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, reduction='mean') + + 三元组损失函数。 + 详情请查看 :class:`mindspore.nn.TripletMarginLoss` 。 + + 参数: + - **anchor** (Tensor) - 从训练集随机选取的样本。数据类型为BasicType。 + - **positive** (Tensor) - 与 `anchor` 为同一类的样本,数据类型与shape与 `anchor` 一致。 + - **negative** (Tensor) - 与 `anchor` 为异类的样本,数据类型与shape与 `anchor` 一致。 + - **margin** (float,可选) - 用于拉进 `anchor` 和 `positive` 之间的距离,拉远 `anchor` 和 `negative` 之间的距离。默认值:1.0。 + - **p** (int,可选) - 成对距离的范数。默认值:2。 + - **eps** (float,可选) - 防止除数为 0。默认值:1e-06。 + - **swap** (bool,可选) - 距离交换。默认值:False。 + - **reduction** (str,可选) - 指定要应用于输出的缩减方式,取值为"mean"、"sum"或"none"。默认值:"mean"。 + + 返回: + Tensor。如果 `reduction` 为"none",其shape为 :math:`(N)`。否则,将返回Scalar。 + + 异常: + - **TypeError** - `anchor` 、 `positive` 或者 `negative` 不是Tensor。 + - **TypeError** - `anchor` 、 `positive` 或者 `negative` 的数据类型不一致。 + - **TypeError** - `margin` 的数据类型不是float。 + - **TypeError** - `p` 的数据类型不是int。 + - **TypeError** - `eps` 的数据类型不是float。 + - **TypeError** - `swap` 的数据类型不是bool。 + - **ValueError** - `anchor` 、 `positive` 和 `negative` 的维度同时小于等于1。 + - **ValueError** - `anchor` 、 `positive` 或 `negative` 的维度大于等于8。 + - **ValueError** - `anchor` 、 `positive` 和 `negative` 三者之间的shape无法广播。 + - **ValueError** - `reduction` 不为"mean"、"sum"或"none"。 \ No newline at end of file diff --git a/docs/api/api_python_en/mindspore.nn.rst b/docs/api/api_python_en/mindspore.nn.rst index 0f863ed8fb7..4224e2b5dbe 100644 --- a/docs/api/api_python_en/mindspore.nn.rst +++ b/docs/api/api_python_en/mindspore.nn.rst @@ -256,6 +256,7 @@ Loss Function mindspore.nn.SmoothL1Loss mindspore.nn.SoftMarginLoss mindspore.nn.SoftmaxCrossEntropyWithLogits + mindspore.nn.TripletMarginLoss Optimizer --------- diff --git a/docs/api/api_python_en/mindspore.ops.rst b/docs/api/api_python_en/mindspore.ops.rst index 72e420da2d4..711d7be6f8b 100644 --- a/docs/api/api_python_en/mindspore.ops.rst +++ b/docs/api/api_python_en/mindspore.ops.rst @@ -76,6 +76,7 @@ Loss Functions mindspore.ops.mse_loss mindspore.ops.nll_loss mindspore.ops.smooth_l1_loss + mindspore.ops.triplet_margin_loss Activation Functions ^^^^^^^^^^^^^^^^^^^^ diff --git a/mindspore/python/mindspore/nn/loss/__init__.py b/mindspore/python/mindspore/nn/loss/__init__.py index fb0b01ad402..164fac8cbe7 100644 --- a/mindspore/python/mindspore/nn/loss/__init__.py +++ b/mindspore/python/mindspore/nn/loss/__init__.py @@ -24,7 +24,7 @@ from mindspore.nn.loss.loss import LossBase, L1Loss, CTCLoss, MSELoss, SmoothL1L SoftmaxCrossEntropyWithLogits, BCELoss, MultiMarginLoss, CosineEmbeddingLoss, \ SampledSoftmaxLoss, PoissonNLLLoss, MultiLabelSoftMarginLoss, DiceLoss, BCEWithLogitsLoss, MultiClassDiceLoss, \ RMSELoss, MAELoss, HuberLoss, CrossEntropyLoss, NLLLoss, KLDivLoss, MarginRankingLoss, GaussianNLLLoss, \ - HingeEmbeddingLoss, MultilabelMarginLoss + HingeEmbeddingLoss, MultilabelMarginLoss, TripletMarginLoss __all__ = ['LossBase', 'L1Loss', 'CTCLoss', 'MSELoss', 'SmoothL1Loss', 'SoftMarginLoss', 'FocalLoss', @@ -32,4 +32,4 @@ __all__ = ['LossBase', 'L1Loss', 'CTCLoss', 'MSELoss', 'SmoothL1Loss', 'SoftMarg 'CosineEmbeddingLoss', 'SampledSoftmaxLoss', 'PoissonNLLLoss', 'MultiLabelSoftMarginLoss', 'DiceLoss', 'MultiClassDiceLoss', 'MultilabelMarginLoss', 'RMSELoss', 'MAELoss', 'HuberLoss', 'CrossEntropyLoss', 'NLLLoss', 'KLDivLoss', 'MarginRankingLoss', - 'GaussianNLLLoss', 'HingeEmbeddingLoss'] + 'GaussianNLLLoss', 'HingeEmbeddingLoss', 'TripletMarginLoss'] diff --git a/mindspore/python/mindspore/nn/loss/loss.py b/mindspore/python/mindspore/nn/loss/loss.py index 007bbb7b354..54f52c52f1a 100644 --- a/mindspore/python/mindspore/nn/loss/loss.py +++ b/mindspore/python/mindspore/nn/loss/loss.py @@ -26,7 +26,6 @@ from mindspore.ops import operations as P from mindspore.ops.operations import _inner_ops as inner from mindspore.ops.operations.nn_ops import MultiMarginLoss as MultiMarginLossOp from mindspore.ops.operations.nn_ops import MultilabelMarginLoss as MultilabelMarginLossOp -from mindspore.ops.operations.nn_ops import TripletMarginLoss as TripletMarginLossOp from mindspore.ops import functional as F from mindspore import nn from mindspore.ops.primitive import constexpr @@ -1880,15 +1879,16 @@ class TripletMarginLoss(LossBase): TripletMarginLoss operation. Creates a criterion that measures the triplet loss given an input - tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`. - This is used for measuring a relative similarity between samples. A triplet - is composed by `a`, `p` and `n` (i.e., `anchor`, `positive examples` and `negative - examples` respectively). The shapes of all input tensors should be + tensors :math:`x`, :math:`positive`, :math:`negative` and a :math:`margin` with a value greater than :math:`0`. + This is used for measuring a relative similarity between samples. + A triplet is composed by `a`, `p` and `n` (i.e., `x`, `positive` and `negative` respectively). + The shapes of all input tensors should be :math:`(N, D)`. - The distance swap is described in detail in the paper `Learning shallow - convolutional feature descriptors with triplet losses` by - V. Balntas, E. Riba et al. + The distance swap is described in detail in the paper + `Learning local feature descriptors with triplets and shallow convolutional neural + networks `_ + by V. Balntas, E. Riba et al. The loss function for each sample in the mini-batch is: @@ -1901,26 +1901,25 @@ class TripletMarginLoss(LossBase): d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p Args: - p (int): The norm degree for pairwise distance. Default: 2. - eps (float): Default: 1e-06. - swap (bool): The distance swap is described in detail in the paper - `Learning shallow convolutional feature descriptors with triplet losses` by - V. Balntas, E. Riba et al. Default: "False". - reduction (str): Apply specific reduction method to the output: 'none', 'mean', 'sum'. Default: "mean". + p (int, optional): The norm degree for pairwise distance. Default: 2. + eps (float, optional): Add small value to avoid division by zero. Default: 1e-06. + swap (bool, optional): The distance swap change the negative distance to the distance between positive + sample and negative sample. Default: "False". + reduction (str, optional): Apply specific reduction method to the output: 'none', 'mean', 'sum'. + Default: "mean". Inputs: - **x** (Tensor) - A sample randomly selected from the training set. Data type must be BasicType. - - **positive** (Tensor) - A sample belonging to the same category as x, with the same type and shape as `x`. - - **negative** (Tensor) - A sample belonging to the different class from x, with the same type and shape as `x`. + - **positive** (Tensor) - A sample belonging to the same category as `x`, with the same type and shape as `x`. + - **negative** (Tensor) - A sample belonging to the different class from `x`, with the same type and shape + as `x`. - **margin** (Tensor) - Make a margin between the positive pair and the negative pair. Outputs: - Union[Tensor, Scalar], if `reduction` is "none", its shape is :math:`(N)`. - Otherwise, a scalar value will be returned. + Tensor. If `reduction` is "none", its shape is :math:`(N)`. Otherwise, a scalar value will be returned. Raises: TypeError: If `x` or `positive` or 'negative' or 'margin' is not a Tensor. - TypeError: If dtype of `x` or `positive` or `negative` is not BasicType. TypeError: If dtype of `x`, `positive` and `negative` is not the same. TypeError: If `margin` is not float32. TypeError: If `p` is not an int. @@ -1933,7 +1932,7 @@ class TripletMarginLoss(LossBase): ValueError: If `reduction` is not one of 'none', 'mean', 'sum'. Supported Platforms: - ``Ascend`` ``GPU`` ``CPU` + ``GPU`` Examples: >>> loss = nn.TripletMarginLoss() @@ -1946,12 +1945,16 @@ class TripletMarginLoss(LossBase): 0.8881968 """ - def __init__(self, p=2, swap=False, eps=1e-6, reduction='mean'): + def __init__(self, p=2, swap=False, eps=1e-06, reduction='mean'): super(TripletMarginLoss, self).__init__() - self.triplet_margin_loss = TripletMarginLossOp(p=p, swap=swap, eps=eps, reduction=reduction) + self.p = p + self.swap = swap + self.eps = eps + self.reduction = reduction def construct(self, x, positive, negative, margin): - return self.triplet_margin_loss(x, positive, negative, margin) + return F.triplet_margin_loss(x, positive, negative, margin=margin, p=self.p, + eps=self.eps, swap=self.swap, reduction=self.reduction) @constexpr diff --git a/mindspore/python/mindspore/ops/function/__init__.py b/mindspore/python/mindspore/ops/function/__init__.py index 5efc8ed0aa6..f00660dcdc5 100644 --- a/mindspore/python/mindspore/ops/function/__init__.py +++ b/mindspore/python/mindspore/ops/function/__init__.py @@ -465,6 +465,7 @@ from .nn_func import ( lp_pool1d, lp_pool2d, mse_loss, + triplet_margin_loss, msort ) from .linalg_func import ( diff --git a/mindspore/python/mindspore/ops/function/nn_func.py b/mindspore/python/mindspore/ops/function/nn_func.py index f93e6da6eeb..669dccaab55 100644 --- a/mindspore/python/mindspore/ops/function/nn_func.py +++ b/mindspore/python/mindspore/ops/function/nn_func.py @@ -35,6 +35,7 @@ from mindspore.ops.operations.nn_ops import MaxUnpool2D, MaxUnpool3D from mindspore.ops.operations.nn_ops import FractionalMaxPoolWithFixedKsize, FractionalMaxPool3DWithFixedKsize from mindspore.ops.operations.nn_ops import PadV3 from mindspore.ops.operations.nn_ops import ChannelShuffle +from mindspore.ops.operations.nn_ops import TripletMarginLoss slice_ = P.Slice() fast_gelu_ = P.FastGeLU() @@ -5611,6 +5612,58 @@ def msort(x): return ops.Sort(axis=0)(x)[0] +def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, reduction='mean'): + """ + TripletMarginLoss operation. + See :class:`mindspore.nn.TripletMarginLoss` for details. + + Args: + anchor (Tensor): A sample randomly selected from the training set. Data type must be BasicType. + positive (Tensor): A sample belonging to the same category as `anchor`, with the same type and shape + as `anchor`. + negative (Tensor): A sample belonging to the different class from `anchor`, with the same type and shape + as `anchor`. + margin (float, optional): Make a margin between the positive pair and the negative pair. Default: 1.0. + p (int, optional): The norm degree for pairwise distance. Default: 2. + eps (float, optional): Add small value to avoid division by zero. Default: 1e-06. + swap (bool, optional): The distance swap change the negative distance to the distance between positive + sample and negative sample. Default: "False". + reduction (str, optional): Apply specific reduction method to the output: 'none', 'mean', 'sum'. + Default: "mean". + + Returns: + Tensor. If `reduction` is "none", its shape is :math:`(N)`. Otherwise, a scalar value will be returned. + + Raises: + TypeError: If `anchor` or `positive` or 'negative' is not a Tensor. + TypeError: If dtype of `anchor`, `positive` and `negative` is not the same. + TypeError: If `margin` is not a float. + TypeError: If `p` is not an int. + TypeError: If `eps` is not a float. + TypeError: If `swap` is not a bool. + ValueError: If dimensions of input `anchor`, `positive` and `negative` are less than or equal to 1 at the + same time. + ValueError: If the dimension of input `anchor` or `positive` or `negative` is bigger than or equal to 8. + ValueError: If shape of `anchor`, `positive` and `negative` cannot broadcast. + ValueError: If `reduction` is not one of 'none', 'mean', 'sum'. + + Supported Platforms: + ``GPU`` + + Examples: + >>> anchor = Tensor(np.array([[0.3, 0.7], [0.5, 0.5]]), mindspore.float32) + >>> positive = Tensor(np.array([[0.4, 0.6], [0.4, 0.6]]), mindspore.float32) + >>> negative = Tensor(np.array([[0.2, 0.9], [0.3, 0.7]]), mindspore.float32) + >>> output = ops.triplet_margin_loss(anchor, positive, negative) + >>> print(output) + 0.8881968 + """ + if not isinstance(margin, Tensor): + margin = Tensor(margin, mstype.float32) + triplet_margin_loss_op = _get_cache_prim(TripletMarginLoss)(p=p, eps=eps, swap=swap, reduction=reduction) + return triplet_margin_loss_op(anchor, positive, negative, margin) + + __all__ = [ 'adaptive_avg_pool1d', 'adaptive_avg_pool2d', @@ -5695,5 +5748,6 @@ __all__ = [ 'max_unpool3d', 'mse_loss', 'msort', + 'triplet_margin_loss', ] __all__.sort() diff --git a/tests/st/nn/test_triplet_margin_loss.py b/tests/st/nn/test_triplet_margin_loss.py new file mode 100644 index 00000000000..95d29a8d07e --- /dev/null +++ b/tests/st/nn/test_triplet_margin_loss.py @@ -0,0 +1,123 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import torch +import numpy as np +import pytest +import mindspore as ms +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_triplet_margin_loss_float64(mode): + """ + Feature: Input type of float64 + Description: Input type of [float64, float64, float64]. + Expectation: success. + """ + context.set_context(mode=mode) + data_type = np.float64 + anchor_array = np.array([[1.3, 20.5, 5.6], + [3.5, 4.8, 7.2], + [0.2, 0.01, 1], + [4, 4.1, 20]]).astype(data_type) + positive_array = np.array([[2., 10., 1.], + [6., 7., 10.], + [13., 4., 1.], + [0.33, -4, -1.5]]).astype(data_type) + negative_array = np.array([[2., 21., 6.], + [68., 9., 10.], + [131., 25., 16.], + [0.31, -0.14, -16.]]).astype(data_type) + margin = np.float32(2.0) + p = 0 + swap = True + reduction = "none" + eps = 1e-5 + + anchor = Tensor(anchor_array) + positive = Tensor(positive_array) + negative = Tensor(negative_array) + ms_margin = Tensor(margin) + triplet_margin_loss = nn.TripletMarginLoss(p=p, eps=eps, swap=swap, reduction=reduction) + output_ms = triplet_margin_loss(anchor, positive, negative, ms_margin) + + torch_anchor = torch.tensor(anchor_array) + torch_positive = torch.tensor(positive_array) + torch_negative = torch.tensor(negative_array) + expect = torch.nn.functional.triplet_margin_loss(torch_anchor, torch_positive, + torch_negative, margin=margin, + p=p, eps=eps, swap=swap, + reduction=reduction) + assert np.allclose(output_ms.asnumpy(), + expect.numpy(), + rtol=1e-4, + atol=1e-4, + equal_nan=False) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_triplet_margin_loss_float32(mode): + """ + Feature: Input type of float32 + Description: Input type of [float32, float32, float32]. + Expectation: success. + """ + context.set_context(mode=mode) + data_type = np.float32 + anchor_array = np.array([[1.3, 20.5, 5.6], + [3.5, 4.8, 7.2], + [0.2, 0.01, 1], + [4, 4.1, 20]]).astype(data_type) + positive_array = np.array([[2., 10., 1.], + [6., 7., 10.], + [13., 4., 1.], + [0.33, -4, -1.5]]).astype(data_type) + negative_array = np.array([[2., 21., 6.], + [68., 9., 10.], + [131., 25., 16.], + [0.31, -0.14, -16.]]).astype(data_type) + margin = np.float32(2.0) + p = 1 + swap = False + reduction = "none" + eps = 1e-6 + + anchor = Tensor(anchor_array) + positive = Tensor(positive_array) + negative = Tensor(negative_array) + ms_margin = Tensor(margin) + triplet_margin_loss = nn.TripletMarginLoss(p=p, eps=eps, swap=swap, reduction=reduction) + output_ms = triplet_margin_loss(anchor, positive, negative, ms_margin) + + torch_anchor = torch.tensor(anchor_array) + torch_positive = torch.tensor(positive_array) + torch_negative = torch.tensor(negative_array) + expect = torch.nn.functional.triplet_margin_loss(torch_anchor, torch_positive, + torch_negative, margin=margin, + p=p, eps=eps, swap=swap, + reduction=reduction) + assert np.allclose(output_ms.asnumpy(), + expect.numpy(), + rtol=1e-4, + atol=1e-4, + equal_nan=False) diff --git a/tests/st/ops/test_ops_triplet_margin_loss.py b/tests/st/ops/test_ops_triplet_margin_loss.py new file mode 100644 index 00000000000..f6704cb58ff --- /dev/null +++ b/tests/st/ops/test_ops_triplet_margin_loss.py @@ -0,0 +1,139 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import torch +import numpy as np +import pytest +import mindspore as ms +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +import mindspore.ops as ops +import mindspore.common.dtype as mstype + + +class NetTripletMarginLoss(nn.Cell): + def __init__(self, margin=Tensor(1.0, mstype.float32), p=2, swap=False, eps=1e-6, reduction="mean"): + super(NetTripletMarginLoss, self).__init__() + self.margin = margin + self.p = p + self.swap = swap + self.eps = eps + self.reduction = reduction + + def construct(self, anchor, positive, negative): + return ops.triplet_margin_loss(anchor, positive, negative, margin=self.margin, p=self.p, + eps=self.eps, swap=self.swap, reduction=self.reduction) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_triplet_margin_loss_float64(mode): + """ + Feature: Input type of float64 + Description: Input type of [float64, float64, float64]. + Expectation: success. + """ + context.set_context(mode=mode) + data_type = np.float64 + anchor_array = np.array([[1.3, 20.5, 5.6], + [3.5, 4.8, 7.2], + [0.2, 0.01, 1], + [4, 4.1, 20]]).astype(data_type) + positive_array = np.array([[2., 10., 1.], + [6., 7., 10.], + [13., 4., 1.], + [0.33, -4, -1.5]]).astype(data_type) + negative_array = np.array([[2., 21., 6.], + [68., 9., 10.], + [131., 25., 16.], + [0.31, -0.14, -16.]]).astype(data_type) + margin = np.float32(2.0) + p = 0 + swap = True + reduction = "none" + eps = 1e-5 + + anchor = Tensor(anchor_array) + positive = Tensor(positive_array) + negative = Tensor(negative_array) + triplet_margin_loss = NetTripletMarginLoss(margin=margin, p=p, eps=eps, + swap=swap, reduction=reduction) + output_ms = triplet_margin_loss(anchor, positive, negative) + + torch_anchor = torch.tensor(anchor_array) + torch_positive = torch.tensor(positive_array) + torch_negative = torch.tensor(negative_array) + expect = torch.nn.functional.triplet_margin_loss(torch_anchor, torch_positive, + torch_negative, margin=margin, + p=p, eps=eps, swap=swap, + reduction=reduction) + assert np.allclose(output_ms.asnumpy(), + expect.numpy(), + rtol=1e-4, + atol=1e-4, + equal_nan=False) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_triplet_margin_loss_float32(mode): + """ + Feature: Input type of float32 + Description: Input type of [float32, float32, float32]. + Expectation: success. + """ + context.set_context(mode=mode) + data_type = np.float32 + anchor_array = np.array([[1.3, 20.5, 5.6], + [3.5, 4.8, 7.2], + [0.2, 0.01, 1], + [4, 4.1, 20]]).astype(data_type) + positive_array = np.array([[2., 10., 1.], + [6., 7., 10.], + [13., 4., 1.], + [0.33, -4, -1.5]]).astype(data_type) + negative_array = np.array([[2., 21., 6.], + [68., 9., 10.], + [131., 25., 16.], + [0.31, -0.14, -16.]]).astype(data_type) + margin = np.float32(2.0) + p = 1 + swap = False + reduction = "none" + eps = 1e-6 + + anchor = Tensor(anchor_array) + positive = Tensor(positive_array) + negative = Tensor(negative_array) + triplet_margin_loss = NetTripletMarginLoss(margin=margin, p=p, eps=eps, + swap=swap, reduction=reduction) + output_ms = triplet_margin_loss(anchor, positive, negative) + + torch_anchor = torch.tensor(anchor_array) + torch_positive = torch.tensor(positive_array) + torch_negative = torch.tensor(negative_array) + expect = torch.nn.functional.triplet_margin_loss(torch_anchor, torch_positive, + torch_negative, margin=margin, + p=p, eps=eps, swap=swap, + reduction=reduction) + assert np.allclose(output_ms.asnumpy(), + expect.numpy(), + rtol=1e-4, + atol=1e-4, + equal_nan=False)