!12258 add multclass diceloss
From: @lijiaqi0612 Reviewed-by: Signed-off-by:
This commit is contained in:
commit
4ac1982c58
|
@ -21,8 +21,9 @@ It shows how well the model works on a dataset and the optimization target which
|
||||||
|
|
||||||
from .loss import L1Loss, MSELoss, SmoothL1Loss, \
|
from .loss import L1Loss, MSELoss, SmoothL1Loss, \
|
||||||
SoftmaxCrossEntropyWithLogits, BCELoss, CosineEmbeddingLoss, \
|
SoftmaxCrossEntropyWithLogits, BCELoss, CosineEmbeddingLoss, \
|
||||||
SampledSoftmaxLoss, DiceLoss, BCEWithLogitsLoss
|
SampledSoftmaxLoss, DiceLoss, BCEWithLogitsLoss, MultiClassDiceLoss
|
||||||
|
|
||||||
|
|
||||||
__all__ = ['L1Loss', 'MSELoss', 'SmoothL1Loss',
|
__all__ = ['L1Loss', 'MSELoss', 'SmoothL1Loss',
|
||||||
'SoftmaxCrossEntropyWithLogits', 'BCELoss', 'BCEWithLogitsLoss',
|
'SoftmaxCrossEntropyWithLogits', 'BCELoss', 'BCEWithLogitsLoss',
|
||||||
'CosineEmbeddingLoss', 'SampledSoftmaxLoss', 'DiceLoss']
|
'CosineEmbeddingLoss', 'SampledSoftmaxLoss', 'DiceLoss', 'MultiClassDiceLoss']
|
||||||
|
|
|
@ -21,6 +21,7 @@ from mindspore.ops import functional as F
|
||||||
from mindspore.ops.primitive import constexpr
|
from mindspore.ops.primitive import constexpr
|
||||||
from mindspore.ops import _selected_ops
|
from mindspore.ops import _selected_ops
|
||||||
from mindspore.nn.cell import Cell
|
from mindspore.nn.cell import Cell
|
||||||
|
from mindspore.nn.layer.activation import get_activation
|
||||||
from mindspore._checkparam import Validator as validator
|
from mindspore._checkparam import Validator as validator
|
||||||
from mindspore._checkparam import Rel
|
from mindspore._checkparam import Rel
|
||||||
from ... import context
|
from ... import context
|
||||||
|
@ -329,14 +330,14 @@ class DiceLoss(_Loss):
|
||||||
Default: 1e-5.
|
Default: 1e-5.
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
- **y_pred** (Tensor) - Tensor of shape (N, ...).
|
- **y_pred** (Tensor) - Tensor of shape (N, ...). The data type must be float16 or float32.
|
||||||
- **y** (Tensor) - Tensor of shape (N, ...).
|
- **y** (Tensor) - Tensor of shape (N, ...). The data type must be float16 or float32.
|
||||||
|
|
||||||
Outputs:
|
Outputs:
|
||||||
Tensor, a tensor of shape with the per-example sampled Dice losses.
|
Tensor, a tensor of shape with the per-example sampled Dice losses.
|
||||||
|
|
||||||
Supported Platforms:
|
Supported Platforms:
|
||||||
``Ascend``
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> loss = nn.DiceLoss(smooth=1e-5)
|
>>> loss = nn.DiceLoss(smooth=1e-5)
|
||||||
|
@ -364,7 +365,7 @@ class DiceLoss(_Loss):
|
||||||
single_dice_coeff = (2 * intersection) / (unionset + self.smooth)
|
single_dice_coeff = (2 * intersection) / (unionset + self.smooth)
|
||||||
dice_loss = 1 - single_dice_coeff / label.shape[0]
|
dice_loss = 1 - single_dice_coeff / label.shape[0]
|
||||||
|
|
||||||
return dice_loss
|
return dice_loss.mean()
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
|
@ -372,6 +373,79 @@ def _check_shape(logits_shape, label_shape):
|
||||||
validator.check('logits_shape', logits_shape, 'label_shape', label_shape)
|
validator.check('logits_shape', logits_shape, 'label_shape', label_shape)
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def _check_weights(weight, label):
|
||||||
|
if weight.shape[0] != label.shape[1]:
|
||||||
|
raise ValueError("The shape of weight should be equal to the shape of label, but the shape of weight is {}, "
|
||||||
|
"and the shape of label is {}.".format(weight.shape, label.shape))
|
||||||
|
|
||||||
|
|
||||||
|
class MultiClassDiceLoss(_Loss):
|
||||||
|
r"""
|
||||||
|
When there are multiple classifications, label is transformed into multiple binary classifications by one hot.
|
||||||
|
For each channel section in the channel, it can be regarded as a binary classification problem, so it can be
|
||||||
|
obtained through the binary loss of each category, and then the average value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
weights (Union[Tensor, None]): Tensor of shape `[num_classes, dim]`.
|
||||||
|
ignore_indiex (Union[int, None]): Class index to ignore.
|
||||||
|
activation (Union[str, Cell]): Activate function applied to the output of the fully connected layer, eg. 'ReLU'.
|
||||||
|
Default: 'Softmax'. Choose from:
|
||||||
|
['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'FastGelu', 'Sigmoid',
|
||||||
|
'PReLU', 'LeakyReLU', 'HSigmoid', 'HSwish', 'ELU', 'LogSigmoid']
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **y_pred** (Tensor) - Tensor of shape (N, ...). The data type must be float16 or float32.
|
||||||
|
- **y** (Tensor) - Tensor of shape (N, ...). The data type must be float16 or float32.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, a tensor of shape with the per-example sampled MultiClass Dice Losses.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> loss = nn.MultiClassDiceLoss(weights=None, ignore_indiex=None, activation="softmax")
|
||||||
|
>>> y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
|
||||||
|
>>> y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
|
||||||
|
>>> output = loss(y_pred, y)
|
||||||
|
>>> print(output)
|
||||||
|
[0.7761003]
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the shapes are different.
|
||||||
|
TypeError: If the type of inputs are not Tensor.
|
||||||
|
"""
|
||||||
|
def __init__(self, weights=None, ignore_indiex=None, activation="softmax"):
|
||||||
|
super(MultiClassDiceLoss, self).__init__()
|
||||||
|
|
||||||
|
self.binarydiceloss = DiceLoss(smooth=1e-5)
|
||||||
|
self.weights = weights if weights is None else validator.check_value_type("weights", weights, [Tensor])
|
||||||
|
self.ignore_indiex = ignore_indiex if ignore_indiex is None else \
|
||||||
|
validator.check_value_type("ignore_indiex", ignore_indiex, [int])
|
||||||
|
self.activation = get_activation(activation) if isinstance(activation, str) else activation
|
||||||
|
if self.activation is not None and not isinstance(self.activation, Cell):
|
||||||
|
raise TypeError("The activation must be str or Cell, but got {}.".format(activation))
|
||||||
|
self.reshape = P.Reshape()
|
||||||
|
|
||||||
|
def construct(self, logits, label):
|
||||||
|
_check_shape(logits.shape, label.shape)
|
||||||
|
total_loss = 0
|
||||||
|
|
||||||
|
if self.activation is not None:
|
||||||
|
logits = self.activation(logits)
|
||||||
|
|
||||||
|
for i in range(label.shape[1]):
|
||||||
|
if i != self.ignore_indiex:
|
||||||
|
dice_loss = self.binarydiceloss(logits[:, i], label[:, i])
|
||||||
|
if self.weights is not None:
|
||||||
|
_check_weights(self.weights, label)
|
||||||
|
dice_loss *= self.weights[i]
|
||||||
|
total_loss += dice_loss
|
||||||
|
|
||||||
|
return total_loss/label.shape[1]
|
||||||
|
|
||||||
|
|
||||||
class SampledSoftmaxLoss(_Loss):
|
class SampledSoftmaxLoss(_Loss):
|
||||||
r"""
|
r"""
|
||||||
Computes the sampled softmax training loss.
|
Computes the sampled softmax training loss.
|
||||||
|
|
|
@ -15,8 +15,8 @@
|
||||||
""" test loss """
|
""" test loss """
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import mindspore.common.dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
import mindspore.nn as nn
|
from mindspore import nn
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from ..ut_filter import non_graph_engine
|
from ..ut_filter import non_graph_engine
|
||||||
|
|
||||||
|
@ -107,3 +107,56 @@ def test_dice_loss_check_shape():
|
||||||
y = Tensor(np.array([[1, 0], [0, 1]]), mstype.float32)
|
y = Tensor(np.array([[1, 0], [0, 1]]), mstype.float32)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
loss(y_pred, y)
|
loss(y_pred, y)
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_class_dice_loss():
|
||||||
|
""" test_multi_class_dice_loss """
|
||||||
|
loss = nn.MultiClassDiceLoss(weights=None, ignore_indiex=None, activation="softmax")
|
||||||
|
y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
|
||||||
|
y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
|
||||||
|
loss(y_pred, y)
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_class_dice_loss_check_shape():
|
||||||
|
""" test_multi_class_dice_loss """
|
||||||
|
loss = nn.MultiClassDiceLoss(weights=None, ignore_indiex=None, activation="softmax")
|
||||||
|
y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
|
||||||
|
y = Tensor(np.array([[1, 0], [0, 1]]), mstype.float32)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
loss(y_pred, y)
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_class_dice_loss_init_weight():
|
||||||
|
""" test_multi_class_dice_loss """
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
loss = nn.MultiClassDiceLoss(weights='1', ignore_indiex=None, activation="softmax")
|
||||||
|
y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
|
||||||
|
y = Tensor(np.array([[1, 0], [0, 1]]), mstype.float32)
|
||||||
|
loss(y_pred, y)
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_class_dice_loss_init_ignore_indiex():
|
||||||
|
""" test_multi_class_dice_loss """
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
loss = nn.MultiClassDiceLoss(weights=None, ignore_indiex="2", activation="softmax")
|
||||||
|
y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
|
||||||
|
y = Tensor(np.array([[1, 0], [0, 1]]), mstype.float32)
|
||||||
|
loss(y_pred, y)
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_class_dice_loss_init_activation():
|
||||||
|
""" test_multi_class_dice_loss """
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
loss = nn.MultiClassDiceLoss(weights=None, ignore_indiex=None, activation=2)
|
||||||
|
y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
|
||||||
|
y = Tensor(np.array([[1, 0], [0, 1]]), mstype.float32)
|
||||||
|
loss(y_pred, y)
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_class_dice_loss_init_activation2():
|
||||||
|
""" test_multi_class_dice_loss """
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
loss = nn.MultiClassDiceLoss(weights=None, ignore_indiex=None, activation='www')
|
||||||
|
y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
|
||||||
|
y = Tensor(np.array([[1, 0], [0, 1]]), mstype.float32)
|
||||||
|
loss(y_pred, y)
|
||||||
|
|
Loading…
Reference in New Issue