composed op CosineEmbeddingLoss

Author: zhaozhenlong, 2020-06-09 11:58:13 +08:00
parent 95c8de970e
commit 19c5921c06
4 changed files with 90 additions and 2 deletions


@@ -322,6 +322,15 @@ class Validator:
             return arg_value
         raise TypeError(f"{msg_prefix} `{arg_name}` must be float.")
+
+    @staticmethod
+    def check_reduce_shape(ori_shape, shape, axis, prim_name):
+        """Checks whether shape is ori_shape reduced on axis"""
+        axis = axis if isinstance(axis, Iterable) else (axis,)
+        exp_shape = [ori_shape[i] for i in range(len(ori_shape)) if i not in axis]
+        if list(shape) != exp_shape:
+            raise ValueError(f'For {prim_name}, {ori_shape} reduce on {axis} should be '
+                             f'{tuple(exp_shape)}, but got {shape}.')
 
 
 class ParamValidator:
     """Parameter validator. NOTICE: this class will be replaced by `class Validator`"""


@@ -20,8 +20,9 @@ It shows how well the model works on a dataset and the optimization target which
 """
 from .loss import L1Loss, MSELoss, SmoothL1Loss, \
-    SoftmaxCrossEntropyWithLogits, SoftmaxCrossEntropyExpand
+    SoftmaxCrossEntropyWithLogits, SoftmaxCrossEntropyExpand, CosineEmbeddingLoss
 
 __all__ = ['L1Loss', 'MSELoss', 'SmoothL1Loss',
            'SoftmaxCrossEntropyWithLogits',
-           'SoftmaxCrossEntropyExpand']
+           'SoftmaxCrossEntropyExpand',
+           'CosineEmbeddingLoss']
 
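With the export added, the new loss is constructed from the `nn` namespace like the existing losses. A short usage sketch (assumes a build containing this commit; the import style is not shown in the diff):

```python
import mindspore.nn as nn

# CosineEmbeddingLoss is now importable alongside L1Loss, MSELoss, etc.
loss_fn = nn.CosineEmbeddingLoss(margin=0.0, reduction="mean")
```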


@@ -17,9 +17,11 @@ import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
+from mindspore.ops.primitive import constexpr
 from mindspore.nn.cell import Cell
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
+from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
 from ... import context
 
 
@@ -329,3 +331,71 @@ class SoftmaxCrossEntropyExpand(Cell):
             loss = self.reduce_mean(loss, -1)
 
         return loss
+
+
+@constexpr
+def _check_reduced_shape_valid(ori_shape, reduced_shape, axis, cls_name):
+    validator.check_reduce_shape(ori_shape, reduced_shape, axis, cls_name)
+
+
+class CosineEmbeddingLoss(_Loss):
+    r"""
+    Computes the similarity between two tensors using cosine distance.
+
+    Given two tensors `x1`, `x2`, and a Tensor label `y` with values 1 or -1:
+
+    .. math::
+        loss(x_1, x_2, y) = \begin{cases}
+            1 - cos(x_1, x_2), & \text{if } y = 1 \\
+            max(0, cos(x_1, x_2) - margin), & \text{if } y = -1 \\
+        \end{cases}
+
+    Args:
+        margin (float): Should be in [-1.0, 1.0]. Default: 0.0.
+        reduction (str): Specifies which reduction to apply to the output. It should be one of
+            "none", "mean", and "sum", meaning no reduction, reduce mean, and reduce sum on output, respectively. Default: "mean".
+
+    Inputs:
+        - **input_x1** (Tensor) - Input tensor.
+        - **input_x2** (Tensor) - Its shape and data type should be the same as `input_x1`'s shape and data type.
+        - **y** (Tensor) - Contains value 1 or -1. Suppose the shape of `input_x1` is
+          :math:`(x_1, x_2, x_3, ..., x_R)`, then the shape of `y` should be :math:`(x_1, x_3, x_4, ..., x_R)`.
+
+    Outputs:
+        - **loss** (Tensor) - If `reduction` is "none", its shape is the same as `y`'s shape; otherwise, a scalar loss value.
+
+    Examples:
+        >>> x1 = Tensor(np.array([[0.3, 0.8], [0.4, 0.3]]), mindspore.float32)
+        >>> x2 = Tensor(np.array([[0.4, 1.2], [-0.4, -0.9]]), mindspore.float32)
+        >>> y = Tensor(np.array([1, -1]), mindspore.int32)
+        >>> cosine_embedding_loss = nn.CosineEmbeddingLoss()
+        >>> cosine_embedding_loss(x1, x2, y)
+        [0.0003426671]
+    """
+
+    def __init__(self, margin=0.0, reduction="mean"):
+        super(CosineEmbeddingLoss, self).__init__(reduction)
+        self.reduce_sum = P.ReduceSum()
+        self.maximum = P.Maximum()
+        validator.check_value_type("margin", margin, [float], self.cls_name)
+        self.margin = validator.check_number_range("margin", margin, -1.0, 1.0, Rel.INC_BOTH, self.cls_name)
+
+    def construct(self, x1, x2, y):
+        F.same_type_shape(x1, x2)
+        _check_reduced_shape_valid(F.shape(x1), F.shape(y), (1,), self.cls_name)
+        # if y == 1, the loss is 1 - cosine(x1, x2)
+        # if y == -1, the loss is max(0, cosine(x1, x2) - margin)
+        np_eps = const_utils.get_np_eps(F.dtype(x1))
+        eps = F.cast(np_eps, F.dtype(x1))
+        prod_sum = self.reduce_sum(x1 * x2, (1,))
+        square1 = self.reduce_sum(F.square(x1), (1,)) + eps
+        square2 = self.reduce_sum(F.square(x2), (1,)) + eps
+        denom = F.sqrt(square1 * square2)
+        cosine = prod_sum / denom
+        pos_value = 1.0 - cosine
+        neg_value = self.maximum(cosine - self.margin, 0.0)
+        zeros = F.zeros_like(cosine)
+        pos_part = F.select(y == 1, pos_value, zeros)
+        neg_part = F.select(y == -1, neg_value, zeros)
+        output_unreduced = pos_part + neg_part
+        return self.get_loss(output_unreduced)
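As a cross-check of the docstring example, the same computation in plain NumPy, mirroring `construct` step by step, reproduces the quoted value. This is a sketch only; using `np.finfo(...).eps` for the epsilon is an assumption about what `const_utils.get_np_eps` returns:

```python
import numpy as np

def cosine_embedding_loss_ref(x1, x2, y, margin=0.0):
    """NumPy reference: eps-stabilized cosine similarity along axis 1,
    then 1 - cos where y == 1 and max(0, cos - margin) where y == -1."""
    eps = np.finfo(x1.dtype).eps
    prod_sum = np.sum(x1 * x2, axis=1)
    square1 = np.sum(np.square(x1), axis=1) + eps
    square2 = np.sum(np.square(x2), axis=1) + eps
    cosine = prod_sum / np.sqrt(square1 * square2)
    pos_part = np.where(y == 1, 1.0 - cosine, 0.0)
    neg_part = np.where(y == -1, np.maximum(cosine - margin, 0.0), 0.0)
    return np.mean(pos_part + neg_part)  # "mean" reduction

x1 = np.array([[0.3, 0.8], [0.4, 0.3]], dtype=np.float32)
x2 = np.array([[0.4, 1.2], [-0.4, -0.9]], dtype=np.float32)
y = np.array([1, -1], dtype=np.int32)
print(cosine_embedding_loss_ref(x1, x2, y))  # ~0.00034267, matching the docstring
```

The first pair (y = 1) contributes 1 - cos ≈ 0.000685 and the second pair (y = -1) has a negative cosine, so its hinge term is 0; the mean over both rows gives ≈ 0.00034267.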


@@ -62,3 +62,11 @@ def test_SoftmaxCrossEntropyExpand():
     logits = Tensor(np.random.randint(0, 9, [100, 10]).astype(np.float32))
     labels = Tensor(np.random.randint(0, 9, [10,]).astype(np.float32))
     _executor.compile(loss, logits, labels)
+
+def test_cosine_embedding_loss():
+    """ test CosineEmbeddingLoss """
+    loss = nn.CosineEmbeddingLoss()
+    x1 = Tensor(np.array([[0.3, 0.8], [0.4, 0.3]]).astype(np.float32))
+    x2 = Tensor(np.array([[0.4, 1.2], [-0.4, -0.9]]).astype(np.float32))
+    label = Tensor(np.array([1, -1]).astype(np.int32))
+    loss(x1, x2, label)
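A natural extension (hypothetical, not part of this commit) would exercise the remaining reduction modes with the same inputs, reusing the test file's existing `nn`, `Tensor`, and `np` imports:

```python
def test_cosine_embedding_loss_reduction_none():
    """ hypothetical: CosineEmbeddingLoss with reduction="none" """
    loss = nn.CosineEmbeddingLoss(reduction="none")
    x1 = Tensor(np.array([[0.3, 0.8], [0.4, 0.3]]).astype(np.float32))
    x2 = Tensor(np.array([[0.4, 1.2], [-0.4, -0.9]]).astype(np.float32))
    label = Tensor(np.array([1, -1]).astype(np.int32))
    loss(x1, x2, label)  # per-sample losses, shape matches label
```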