forked from mindspore-Ecosystem/mindspore
commit 9893b3d128
@@ -19,11 +19,11 @@ Cells of loss function. Loss function in machine learning is the target of the model.
 It shows how well the model works on a dataset and the optimization target which the optimizer is searching.
 """

-from .loss import L1Loss, MSELoss, SmoothL1Loss, \
+from .loss import L1Loss, MSELoss, SmoothL1Loss, FocalLoss, \
     SoftmaxCrossEntropyWithLogits, BCELoss, CosineEmbeddingLoss, \
     SampledSoftmaxLoss, DiceLoss, BCEWithLogitsLoss, MultiClassDiceLoss


-__all__ = ['L1Loss', 'MSELoss', 'SmoothL1Loss',
+__all__ = ['L1Loss', 'MSELoss', 'SmoothL1Loss', 'FocalLoss',
            'SoftmaxCrossEntropyWithLogits', 'BCELoss', 'BCEWithLogitsLoss',
            'CosineEmbeddingLoss', 'SampledSoftmaxLoss', 'DiceLoss', 'MultiClassDiceLoss']
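With FocalLoss exported here, it becomes reachable through the nn namespace like the other losses. A minimal usage sketch, assuming a MindSpore build that includes this commit (the values mirror the docstring example below):

    import mindspore.common.dtype as mstype
    from mindspore import nn
    from mindspore.common.tensor import Tensor

    predict = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
    target = Tensor([[1], [1], [0]], mstype.int32)
    focalloss = nn.FocalLoss(gamma=2.0, reduction='mean')
    print(focalloss(predict, target))  # scalar mean focal loss over the batch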
@@ -13,11 +13,13 @@
 # limitations under the License.
 # ============================================================================
 """loss"""
+import mindspore
 import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
+from mindspore import nn
 from mindspore.ops.primitive import constexpr
 from mindspore.ops import _selected_ops
 from mindspore.nn.cell import Cell
@@ -896,3 +898,110 @@ class BCEWithLogitsLoss(_Loss):
             pos_weight = ones_input
         loss = self.bce_with_logits_loss(predict, target, weight, pos_weight)
         return loss
+
+
+@constexpr
+def _check_ndim(predict_ndim, target_ndim):
+    validator.check_int(predict_ndim, target_ndim, Rel.EQ, 'predict_ndim', 'target_ndim')
+
+
+@constexpr
+def _check_channel_and_shape(target, predict):
+    if target not in (predict, 1):
+        raise ValueError("The target must have a channel or the same shape as predict.")
+
+
+@constexpr
+def _check_predict_channel(predict):
+    if predict == 1:
+        raise NotImplementedError("Single channel prediction is not supported.")
+class FocalLoss(_Loss):
+    r"""
+    The loss function proposed by Kaiming He's team in the paper ``Focal Loss for Dense Object Detection`` improves
+    the performance of image object detection. It is designed to address the imbalance between classes and the
+    differing classification difficulty of examples.
+
+    Args:
+        gamma (float): Gamma is used to adjust the steepness of the weight curve in focal loss. Default: 2.0.
+        weight (Union[Tensor, None]): A rescaling weight applied to the loss of each batch element. If None, no
+            weights are applied. Default: None.
+        reduction (str): Type of reduction to be applied to the loss. The optional values are "mean", "sum", and
+            "none". If "none", no reduction is performed. Default: "mean".
+
+    Inputs:
+        - **predict** (Tensor) - Input logits with shape BCH[WD], where C is the number of classes and is
+          greater than 1.
+        - **target** (Tensor) - Tensor of shape B1H[WD] or BCH[WD]. If the shape is B1H[WD], the expected value of
+          target is the class index in the range [0, C-1], where C is the number of classes.
+
+    Outputs:
+        Tensor, the computed focal loss: the per-example losses if ``reduction`` is "none", otherwise the reduced
+        scalar.
+
+    Raises:
+        TypeError: If the data type of ``gamma`` is not float.
+        TypeError: If ``weight`` is not a Tensor.
+        ValueError: If the shape of ``target`` is different from that of ``predict``.
+        ValueError: If the channel of ``target`` is not 1 and the shape of ``target`` is different from that of
+            ``predict``.
+        ValueError: If ``reduction`` is not one of 'none', 'mean', 'sum'.
+
+    Examples:
+        >>> predict = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
+        >>> target = Tensor([[1], [1], [0]], mstype.int32)
+        >>> focalloss = nn.FocalLoss(weight=Tensor([1, 2]), gamma=2.0, reduction='mean')
+        >>> output = focalloss(predict, target)
+        >>> print(output)
+        0.33365273
+    """
+
+    def __init__(self, weight=None, gamma=2.0, reduction='mean'):
+        super(FocalLoss, self).__init__(reduction=reduction)
+
+        self.gamma = validator.check_value_type("gamma", gamma, [float])
+        if weight is not None and not isinstance(weight, Tensor):
+            raise TypeError("The type of weight should be Tensor, but got {}.".format(type(weight)))
+        self.weight = weight
+        self.expand_dims = P.ExpandDims()
+        self.gather_d = P.GatherD()
+        self.squeeze = P.Squeeze(axis=1)
+        self.tile = P.Tile()
+        self.cast = P.Cast()
+
+    def construct(self, predict, target):
+        targets = target
+        _check_ndim(predict.ndim, targets.ndim)
+        _check_channel_and_shape(targets.shape[1], predict.shape[1])
+        _check_predict_channel(predict.shape[1])
+
+        if predict.ndim > 2:
+            predict = predict.view(predict.shape[0], predict.shape[1], -1)
+            targets = targets.view(targets.shape[0], targets.shape[1], -1)
+        else:
+            predict = self.expand_dims(predict, 2)
+            targets = self.expand_dims(targets, 2)
+
+        log_probability = nn.LogSoftmax(1)(predict)
+
+        if target.shape[1] == 1:
+            log_probability = self.gather_d(log_probability, 1, self.cast(targets, mindspore.int32))
+            log_probability = self.squeeze(log_probability)
+
+        probability = F.exp(log_probability)
+
+        if self.weight is not None:
+            convert_weight = self.weight[None, :, None]
+            convert_weight = self.tile(convert_weight, (targets.shape[0], 1, targets.shape[2]))
+            if target.shape[1] == 1:
+                convert_weight = self.gather_d(convert_weight, 1, self.cast(targets, mindspore.int32))
+                convert_weight = self.squeeze(convert_weight)
+            probability = probability * convert_weight
+
+        weight = F.pows(-probability + 1.0, self.gamma)
+        if target.shape[1] == 1:
+            loss = (-weight * log_probability).mean(axis=1)
+        else:
+            loss = (-weight * targets * log_probability).mean(axis=-1)
+
+        return self.get_loss(loss)
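For class-index targets, construct above evaluates the focal-loss formula from the paper, FL(p_t) = -(1 - p_t)^gamma * log(p_t), where p_t is the softmax probability assigned to the labelled class; when weight is given, the code folds the class weight into p_t before the modulating factor is applied. A minimal NumPy sketch of the unweighted, mean-reduced case; the function and variable names are illustrative only, not part of the commit:

    import numpy as np

    def focal_loss_reference(logits, target, gamma=2.0):
        """logits: (B, C) float array; target: (B, 1) integer class indices."""
        shifted = logits - logits.max(axis=1, keepdims=True)   # numerically stable log-softmax
        log_prob = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
        log_pt = np.take_along_axis(log_prob, target, axis=1)  # log p_t, shape (B, 1)
        pt = np.exp(log_pt)
        per_example = -((1.0 - pt) ** gamma) * log_pt          # FL(p_t)
        return per_example.mean()

    logits = np.array([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], dtype=np.float32)
    target = np.array([[1], [1], [0]])
    print(focal_loss_reference(logits, target))

Raising gamma shrinks the contribution of well-classified examples (p_t near 1), which is the class-imbalance mechanism the docstring describes.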
@@ -52,18 +52,30 @@ geswitch = P.GeSwitch()
 addn = P.AddN()
 absolute = P.Abs()
 tensor_add = P.Add()
+add = tensor_add
 neg_tensor = P.Neg()
 tensor_lt = P.Less()
+less = tensor_lt
 tensor_le = P.LessEqual()
+le = tensor_le
 tensor_gt = P.Greater()
+gt = tensor_gt
 tensor_ge = P.GreaterEqual()
+ge = tensor_ge
 tensor_sub = P.Sub()
+sub = tensor_sub
 tensor_mul = P.Mul()
+mul = tensor_mul
 tensor_div = P.RealDiv()
+div = tensor_div
 tensor_floordiv = P.FloorDiv()
+floordiv = tensor_floordiv
 tensor_pow = P.Pow()
+pows = tensor_pow
 tensor_mod = P.FloorMod()
+floormod = tensor_mod
 tensor_exp = P.Exp()
+exp = tensor_exp
 tensor_expm1 = P.Expm1()
 strided_slice = P.StridedSlice()
 same_type_shape = P.SameTypeShape()
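FocalLoss.construct leans on two of these new aliases, F.exp and F.pows. A small sketch of the modulating factor (1 - p_t)^gamma built from them, with illustrative values and assuming a build that includes this commit:

    import mindspore.common.dtype as mstype
    from mindspore.common.tensor import Tensor
    from mindspore.ops import functional as F

    log_pt = Tensor([-0.1, -0.7, -2.3], mstype.float32)  # log-probabilities of the labelled class
    pt = F.exp(log_pt)                                   # p_t
    print(F.pows(-pt + 1.0, 2.0))                        # (1 - p_t)^gamma with gamma = 2.0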
@@ -91,6 +91,50 @@ def test_cosine_embedding_loss():
     loss(x1, x2, label)


+def test_focal_loss():
+    """ test_FocalLoss """
+    x1 = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
+    x2 = Tensor([[1], [1], [0]], mstype.int32)
+    focalloss = nn.FocalLoss()
+    focalloss(x1, x2)
+
+
+def test_focal_loss_gamma():
+    """ test_FocalLoss """
+    x1 = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
+    x2 = Tensor([[1], [1], [0]], mstype.int32)
+    with pytest.raises(TypeError):
+        focalloss = nn.FocalLoss(weight=None, gamma="mmm", reduction='mean')
+        focalloss(x1, x2)
+
+
+def test_focal_loss_weight():
+    """ test_FocalLoss """
+    x1 = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
+    x2 = Tensor([[1], [1]], mstype.int32)
+    with pytest.raises(TypeError):
+        focalloss = nn.FocalLoss(weight='a', gamma=2.0, reduction='mean')
+        focalloss(x1, x2)
+
+
+def test_focal_loss_reduction():
+    """ test_FocalLoss """
+    x1 = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
+    x2 = Tensor([[1], [1], [0]], mstype.int32)
+    with pytest.raises(ValueError):
+        focalloss = nn.FocalLoss(weight=None, gamma=2.0, reduction='m')
+        focalloss(x1, x2)
+
+
+def test_focal_loss_input():
+    """ test_FocalLoss """
+    x1 = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
+    x2 = Tensor([[1]], mstype.int32)
+    focalloss = nn.FocalLoss(weight=None, gamma=2.0, reduction='mean')
+    with pytest.raises(ValueError):
+        focalloss(x1, x2)
+
+
 def test_dice_loss():
     """ test_dice_loss """
     loss = nn.DiceLoss()
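One further check that could sit alongside these tests, a hedged sketch (not part of this commit) assuming reduction='none' is accepted as the docstring states; it reuses the imports of the tests above:

    def test_focal_loss_reduction_none():
        """ test_FocalLoss """
        x1 = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
        x2 = Tensor([[1], [1], [0]], mstype.int32)
        focalloss = nn.FocalLoss(weight=None, gamma=2.0, reduction='none')
        focalloss(x1, x2)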