From 19c5921c065a0dfc118e9e182a2dc2fb1aaa7aa7 Mon Sep 17 00:00:00 2001
From: zhaozhenlong
Date: Tue, 9 Jun 2020 11:58:13 +0800
Subject: [PATCH] composed op CosineEmbeddingLoss

---
 mindspore/_checkparam.py        |  9 +++++
 mindspore/nn/loss/__init__.py   |  5 ++-
 mindspore/nn/loss/loss.py       | 70 +++++++++++++++++++++++++++++++++
 tests/ut/python/nn/test_loss.py |  8 ++++
 4 files changed, 90 insertions(+), 2 deletions(-)

diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py
index d8ca5a9845a..880d26bfad0 100644
--- a/mindspore/_checkparam.py
+++ b/mindspore/_checkparam.py
@@ -322,6 +322,15 @@ class Validator:
             return arg_value
         raise TypeError(f"{msg_prefix} `{arg_name}` must be float.")
 
+    @staticmethod
+    def check_reduce_shape(ori_shape, shape, axis, prim_name):
+        """Checks whether `shape` is `ori_shape` reduced on `axis`."""
+        axis = axis if isinstance(axis, Iterable) else (axis,)
+        exp_shape = [ori_shape[i] for i in range(len(ori_shape)) if i not in axis]
+        if list(shape) != exp_shape:
+            raise ValueError(f'For {prim_name}, {ori_shape} reduced on {axis} should be '
+                             f'{tuple(exp_shape)}, but got {shape}.')
+
 
 class ParamValidator:
     """Parameter validator. NOTICE: this class will be replaced by `class Validator`"""
diff --git a/mindspore/nn/loss/__init__.py b/mindspore/nn/loss/__init__.py
index f08f5aa721c..ce5870699b0 100644
--- a/mindspore/nn/loss/__init__.py
+++ b/mindspore/nn/loss/__init__.py
@@ -20,8 +20,9 @@ It shows how well the model works on a dataset and the optimization target whic
 """
 
 from .loss import L1Loss, MSELoss, SmoothL1Loss, \
-    SoftmaxCrossEntropyWithLogits, SoftmaxCrossEntropyExpand
+    SoftmaxCrossEntropyWithLogits, SoftmaxCrossEntropyExpand, CosineEmbeddingLoss
 
 __all__ = ['L1Loss', 'MSELoss', 'SmoothL1Loss',
            'SoftmaxCrossEntropyWithLogits',
-           'SoftmaxCrossEntropyExpand']
+           'SoftmaxCrossEntropyExpand',
+           'CosineEmbeddingLoss']
diff --git a/mindspore/nn/loss/loss.py b/mindspore/nn/loss/loss.py
index ac419c32c3f..426f111bb20 100644
--- a/mindspore/nn/loss/loss.py
+++ b/mindspore/nn/loss/loss.py
@@ -17,9 +17,11 @@ import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
+from mindspore.ops.primitive import constexpr
 from mindspore.nn.cell import Cell
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
+from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
 from ... import context
 
 
@@ -329,3 +331,71 @@ class SoftmaxCrossEntropyExpand(Cell):
             loss = self.reduce_mean(loss, -1)
 
         return loss
+
+
+@constexpr
+def _check_reduced_shape_valid(ori_shape, reduced_shape, axis, cls_name):
+    validator.check_reduce_shape(ori_shape, reduced_shape, axis, cls_name)
+
+class CosineEmbeddingLoss(_Loss):
+    r"""
+    Computes a loss that measures the similarity between two tensors using cosine distance.
+
+    Given two tensors `x1`, `x2`, and a Tensor label `y` with values 1 or -1:
+
+    .. math::
+        loss(x_1, x_2, y) = \begin{cases}
+        1-cos(x_1, x_2), & \text{if } y = 1\\
+        max(0, cos(x_1, x_2)-margin), & \text{if } y = -1\\
+        \end{cases}
+
+    Args:
+        margin (float): Should be in [-1.0, 1.0]. Default: 0.0.
+        reduction (str): Specifies which reduction to apply to the output. It should be one of
+            "none", "mean", "sum", meaning no reduction, reduce mean or sum on output, respectively. Default: "mean".
+
+    Inputs:
+        - **input_x1** (Tensor) - Input tensor.
+        - **input_x2** (Tensor) - Tensor with the same shape and data type as `input_x1`.
+        - **y** (Tensor) - Contains value 1 or -1. Suppose the shape of `input_x1` is
+          :math:`(x_1, x_2, x_3, ..., x_R)`, then the shape of `y` should be :math:`(x_1, x_3, x_4, ..., x_R)`.
+
+    Outputs:
+        - **loss** (Tensor) - If `reduction` is "none", its shape is the same as `y`'s shape; otherwise, a scalar value.
+
+    Examples:
+        >>> x1 = Tensor(np.array([[0.3, 0.8], [0.4, 0.3]]), mindspore.float32)
+        >>> x2 = Tensor(np.array([[0.4, 1.2], [-0.4, -0.9]]), mindspore.float32)
+        >>> y = Tensor(np.array([1, -1]), mindspore.int32)
+        >>> cosine_embedding_loss = nn.CosineEmbeddingLoss()
+        >>> cosine_embedding_loss(x1, x2, y)
+        [0.0003426671]
+    """
+    def __init__(self, margin=0.0, reduction="mean"):
+        super(CosineEmbeddingLoss, self).__init__(reduction)
+        self.reduce_sum = P.ReduceSum()
+        self.maximum = P.Maximum()
+        validator.check_value_type("margin", margin, [float], self.cls_name)
+        self.margin = validator.check_number_range("margin", margin, -1.0, 1.0, Rel.INC_BOTH, self.cls_name)
+
+    def construct(self, x1, x2, y):
+        F.same_type_shape(x1, x2)
+        _check_reduced_shape_valid(F.shape(x1), F.shape(y), (1,), self.cls_name)
+        # if y == 1: loss = 1 - cosine(x1, x2)
+        # if y == -1: loss = max(0, cosine(x1, x2) - margin)
+        np_eps = const_utils.get_np_eps(F.dtype(x1))
+        eps = F.cast(np_eps, F.dtype(x1))
+        prod_sum = self.reduce_sum(x1 * x2, (1,))
+        square1 = self.reduce_sum(F.square(x1), (1,)) + eps
+        square2 = self.reduce_sum(F.square(x2), (1,)) + eps
+        denom = F.sqrt(square1 * square2)
+        cosine = prod_sum / denom
+
+        pos_value = 1.0 - cosine
+        neg_value = self.maximum(cosine - self.margin, 0.0)
+        zeros = F.zeros_like(cosine)
+        pos_part = F.select(y == 1, pos_value, zeros)
+        neg_part = F.select(y == -1, neg_value, zeros)
+        output_unreduced = pos_part + neg_part
+
+        return self.get_loss(output_unreduced)
diff --git a/tests/ut/python/nn/test_loss.py b/tests/ut/python/nn/test_loss.py
index 21e30068187..7d2329dcfe2 100644
--- a/tests/ut/python/nn/test_loss.py
+++ b/tests/ut/python/nn/test_loss.py
@@ -62,3 +62,11 @@ def test_SoftmaxCrossEntropyExpand():
     logits = Tensor(np.random.randint(0, 9, [100, 10]).astype(np.float32))
     labels = Tensor(np.random.randint(0, 9, [10,]).astype(np.float32))
     _executor.compile(loss, logits, labels)
+
+def test_cosine_embedding_loss():
+    """ test CosineEmbeddingLoss """
+    loss = nn.CosineEmbeddingLoss()
+    x1 = Tensor(np.array([[0.3, 0.8], [0.4, 0.3]]).astype(np.float32))
+    x2 = Tensor(np.array([[0.4, 1.2], [-0.4, -0.9]]).astype(np.float32))
+    label = Tensor(np.array([1, -1]).astype(np.int32))
+    loss(x1, x2, label)
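As a sanity check on the formula in the docstring above, the math in construct() can be mirrored with a small NumPy reference outside MindSpore. The sketch below is not part of the patch, and the helper name cosine_embedding_loss_ref is made up for illustration; only numpy is assumed.

    import numpy as np

    def cosine_embedding_loss_ref(x1, x2, y, margin=0.0, reduction="mean"):
        """Reference: 1 - cos(x1, x2) where y == 1, max(0, cos(x1, x2) - margin) where y == -1."""
        eps = np.finfo(x1.dtype).eps                       # small epsilon, analogous to get_np_eps
        prod_sum = np.sum(x1 * x2, axis=1)
        square1 = np.sum(np.square(x1), axis=1) + eps
        square2 = np.sum(np.square(x2), axis=1) + eps
        cosine = prod_sum / np.sqrt(square1 * square2)
        pos = 1.0 - cosine                                 # applied where y == 1
        neg = np.maximum(cosine - margin, 0.0)             # applied where y == -1
        out = np.where(y == 1, pos, 0.0) + np.where(y == -1, neg, 0.0)
        if reduction == "mean":
            return out.mean()
        if reduction == "sum":
            return out.sum()
        return out

    x1 = np.array([[0.3, 0.8], [0.4, 0.3]], dtype=np.float32)
    x2 = np.array([[0.4, 1.2], [-0.4, -0.9]], dtype=np.float32)
    y = np.array([1, -1], dtype=np.int32)
    print(cosine_embedding_loss_ref(x1, x2, y))            # ~0.000343

Run on the same inputs as the docstring Examples block, this reference gives approximately 3.43e-4, consistent with the 0.0003426671 value documented for the composed op.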