commit
86bbd1dc98
|
@ -20,8 +20,8 @@ It shows how well the model works on a dataset and the optimization target which
|
|||
"""
|
||||
|
||||
from .loss import L1Loss, MSELoss, SmoothL1Loss, \
|
||||
SoftmaxCrossEntropyWithLogits, CosineEmbeddingLoss
|
||||
SoftmaxCrossEntropyWithLogits, BCELoss, CosineEmbeddingLoss
|
||||
|
||||
__all__ = ['L1Loss', 'MSELoss', 'SmoothL1Loss',
|
||||
'SoftmaxCrossEntropyWithLogits',
|
||||
'SoftmaxCrossEntropyWithLogits', 'BCELoss',
|
||||
'CosineEmbeddingLoss']
|
||||
|
|
|
@ -262,6 +262,67 @@ class SoftmaxCrossEntropyWithLogits(_Loss):
|
|||
return self.get_loss(x)
|
||||
|
||||
|
||||
class BCELoss(_Loss):
|
||||
r"""
|
||||
BCELoss creates a criterion to measure the Binary Cross Entropy between the true labels and predicted labels.
|
||||
|
||||
Note:
|
||||
Set the predicted labels as :math:`x`, true labels as :math:`y`, the output loss as :math:`\ell(x, y)`.
|
||||
Let,
|
||||
|
||||
.. math::
|
||||
L = \{l_1,\dots,l_N\}^\top, \quad
|
||||
l_n = - w_n \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right]
|
||||
|
||||
Then,
|
||||
|
||||
.. math::
|
||||
\ell(x, y) = \begin{cases}
|
||||
L, & \text{if reduction} = \text{`none';}\\
|
||||
\operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
|
||||
\operatorname{sum}(L), & \text{if reduction} = \text{`sum'.}
|
||||
\end{cases}
|
||||
|
||||
Args:
|
||||
weight (Tensor, optional): A rescaling weight applied to the loss of each batch element.
|
||||
And it must have same shape and data type as `inputs`. Default: None
|
||||
reduction (str): Specifies the reduction to be applied to the output.
|
||||
Its value must be one of 'none', 'mean', 'sum'. Default: 'none'.
|
||||
|
||||
Inputs:
|
||||
- **inputs** (Tensor) - The input Tensor. The data type must be float16 or float32.
|
||||
- **labels** (Tensor) - The label Tensor which has same shape and data type as `inputs`.
|
||||
|
||||
Outputs:
|
||||
Tensor or Scalar, if `reduction` is 'none', then output is a tensor and has the same shape as `inputs`.
|
||||
Otherwise, the output is a scalar. default: 'none'
|
||||
|
||||
Examples:
|
||||
>>> weight = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 3.3, 2.2]]), mindspore.float32)
|
||||
>>> loss = nn.BCELoss(weight=weight, reduction='mean')
|
||||
>>> inputs = Tensor(np.array([[0.1, 0.2, 0.3], [0.5, 0.7, 0.9]]), mindspore.float32)
|
||||
>>> labels = Tensor(np.array([[0, 1, 0], [0, 0, 1]]), mindspore.float32)
|
||||
>>> loss(inputs, labels)
|
||||
"""
|
||||
|
||||
def __init__(self, weight=None, reduction='none'):
|
||||
super(BCELoss, self).__init__()
|
||||
self.binary_cross_entropy = P.BinaryCrossEntropy(reduction=reduction)
|
||||
self.weight_one = weight is None
|
||||
if not self.weight_one:
|
||||
self.weight = weight
|
||||
else:
|
||||
self.ones = P.OnesLike()
|
||||
|
||||
def construct(self, inputs, labels):
|
||||
if self.weight_one:
|
||||
weight = self.ones(inputs)
|
||||
else:
|
||||
weight = self.weight
|
||||
loss = self.binary_cross_entropy(inputs, labels, weight)
|
||||
return loss
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_reduced_shape_valid(ori_shape, reduced_shape, axis, cls_name):
|
||||
validator.check_reduce_shape(ori_shape, reduced_shape, axis, cls_name)
|
||||
|
|
|
@ -53,6 +53,34 @@ def test_SoftmaxCrossEntropyWithLogits_reduce():
|
|||
loss(logits, labels)
|
||||
|
||||
|
||||
def test_BCELoss():
|
||||
""" test_BCELoss """
|
||||
loss = nn.BCELoss()
|
||||
|
||||
inputs_data = Tensor(np.array([[0.1, 0.2, 0.3], [0.5, 0.7, 0.9]]).astype(np.float32))
|
||||
target_data = Tensor(np.array([[0, 1, 0], [0, 0, 1]]).astype(np.float32))
|
||||
loss(inputs_data, target_data)
|
||||
|
||||
|
||||
def test_BCELoss_reduce():
|
||||
""" test_BCELoss """
|
||||
loss = nn.BCELoss(reduction='mean')
|
||||
|
||||
inputs_data = Tensor(np.array([[0.1, 0.2, 0.3], [0.5, 0.7, 0.9]]).astype(np.float32))
|
||||
target_data = Tensor(np.array([[0, 1, 0], [0, 0, 1]]).astype(np.float32))
|
||||
loss(inputs_data, target_data)
|
||||
|
||||
|
||||
def test_BCELoss_weight():
|
||||
""" test_BCELoss """
|
||||
weight = Tensor(np.array([[1.0, 2.0, 3.0], [2.2, 2.6, 3.9]]).astype(np.float32))
|
||||
loss = nn.BCELoss(weight=weight)
|
||||
|
||||
inputs_data = Tensor(np.array([[0.1, 0.2, 0.3], [0.5, 0.7, 0.9]]).astype(np.float32))
|
||||
target_data = Tensor(np.array([[0, 1, 0], [0, 0, 1]]).astype(np.float32))
|
||||
loss(inputs_data, target_data)
|
||||
|
||||
|
||||
def test_cosine_embedding_loss():
|
||||
""" test CosineEmbeddingLoss """
|
||||
loss = nn.CosineEmbeddingLoss()
|
||||
|
|
Loading…
Reference in New Issue