!12258 add multclass diceloss

From: @lijiaqi0612
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-02-20 09:42:31 +08:00 committed by Gitee
commit 4ac1982c58
3 changed files with 136 additions and 8 deletions

View File

@ -21,8 +21,9 @@ It shows how well the model works on a dataset and the optimization target which
from .loss import L1Loss, MSELoss, SmoothL1Loss, \ from .loss import L1Loss, MSELoss, SmoothL1Loss, \
SoftmaxCrossEntropyWithLogits, BCELoss, CosineEmbeddingLoss, \ SoftmaxCrossEntropyWithLogits, BCELoss, CosineEmbeddingLoss, \
SampledSoftmaxLoss, DiceLoss, BCEWithLogitsLoss SampledSoftmaxLoss, DiceLoss, BCEWithLogitsLoss, MultiClassDiceLoss
__all__ = ['L1Loss', 'MSELoss', 'SmoothL1Loss', __all__ = ['L1Loss', 'MSELoss', 'SmoothL1Loss',
'SoftmaxCrossEntropyWithLogits', 'BCELoss', 'BCEWithLogitsLoss', 'SoftmaxCrossEntropyWithLogits', 'BCELoss', 'BCEWithLogitsLoss',
'CosineEmbeddingLoss', 'SampledSoftmaxLoss', 'DiceLoss'] 'CosineEmbeddingLoss', 'SampledSoftmaxLoss', 'DiceLoss', 'MultiClassDiceLoss']

View File

@ -21,6 +21,7 @@ from mindspore.ops import functional as F
from mindspore.ops.primitive import constexpr from mindspore.ops.primitive import constexpr
from mindspore.ops import _selected_ops from mindspore.ops import _selected_ops
from mindspore.nn.cell import Cell from mindspore.nn.cell import Cell
from mindspore.nn.layer.activation import get_activation
from mindspore._checkparam import Validator as validator from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel from mindspore._checkparam import Rel
from ... import context from ... import context
@ -329,14 +330,14 @@ class DiceLoss(_Loss):
Default: 1e-5. Default: 1e-5.
Inputs: Inputs:
- **y_pred** (Tensor) - Tensor of shape (N, ...). - **y_pred** (Tensor) - Tensor of shape (N, ...). The data type must be float16 or float32.
- **y** (Tensor) - Tensor of shape (N, ...). - **y** (Tensor) - Tensor of shape (N, ...). The data type must be float16 or float32.
Outputs: Outputs:
Tensor, a tensor of shape with the per-example sampled Dice losses. Tensor, a tensor of shape with the per-example sampled Dice losses.
Supported Platforms: Supported Platforms:
``Ascend`` ``Ascend`` ``GPU`` ``CPU``
Examples: Examples:
>>> loss = nn.DiceLoss(smooth=1e-5) >>> loss = nn.DiceLoss(smooth=1e-5)
@ -364,7 +365,7 @@ class DiceLoss(_Loss):
single_dice_coeff = (2 * intersection) / (unionset + self.smooth) single_dice_coeff = (2 * intersection) / (unionset + self.smooth)
dice_loss = 1 - single_dice_coeff / label.shape[0] dice_loss = 1 - single_dice_coeff / label.shape[0]
return dice_loss return dice_loss.mean()
@constexpr @constexpr
@ -372,6 +373,79 @@ def _check_shape(logits_shape, label_shape):
validator.check('logits_shape', logits_shape, 'label_shape', label_shape) validator.check('logits_shape', logits_shape, 'label_shape', label_shape)
@constexpr
def _check_weights(weight, label):
if weight.shape[0] != label.shape[1]:
raise ValueError("The shape of weight should be equal to the shape of label, but the shape of weight is {}, "
"and the shape of label is {}.".format(weight.shape, label.shape))
class MultiClassDiceLoss(_Loss):
r"""
When there are multiple classifications, label is transformed into multiple binary classifications by one hot.
For each channel section in the channel, it can be regarded as a binary classification problem, so it can be
obtained through the binary loss of each category, and then the average value.
Args:
weights (Union[Tensor, None]): Tensor of shape `[num_classes, dim]`.
ignore_indiex (Union[int, None]): Class index to ignore.
activation (Union[str, Cell]): Activate function applied to the output of the fully connected layer, eg. 'ReLU'.
Default: 'Softmax'. Choose from:
['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'FastGelu', 'Sigmoid',
'PReLU', 'LeakyReLU', 'HSigmoid', 'HSwish', 'ELU', 'LogSigmoid']
Inputs:
- **y_pred** (Tensor) - Tensor of shape (N, ...). The data type must be float16 or float32.
- **y** (Tensor) - Tensor of shape (N, ...). The data type must be float16 or float32.
Outputs:
Tensor, a tensor of shape with the per-example sampled MultiClass Dice Losses.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> loss = nn.MultiClassDiceLoss(weights=None, ignore_indiex=None, activation="softmax")
>>> y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
>>> y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
>>> output = loss(y_pred, y)
>>> print(output)
[0.7761003]
Raises:
ValueError: If the shapes are different.
TypeError: If the type of inputs are not Tensor.
"""
def __init__(self, weights=None, ignore_indiex=None, activation="softmax"):
super(MultiClassDiceLoss, self).__init__()
self.binarydiceloss = DiceLoss(smooth=1e-5)
self.weights = weights if weights is None else validator.check_value_type("weights", weights, [Tensor])
self.ignore_indiex = ignore_indiex if ignore_indiex is None else \
validator.check_value_type("ignore_indiex", ignore_indiex, [int])
self.activation = get_activation(activation) if isinstance(activation, str) else activation
if self.activation is not None and not isinstance(self.activation, Cell):
raise TypeError("The activation must be str or Cell, but got {}.".format(activation))
self.reshape = P.Reshape()
def construct(self, logits, label):
_check_shape(logits.shape, label.shape)
total_loss = 0
if self.activation is not None:
logits = self.activation(logits)
for i in range(label.shape[1]):
if i != self.ignore_indiex:
dice_loss = self.binarydiceloss(logits[:, i], label[:, i])
if self.weights is not None:
_check_weights(self.weights, label)
dice_loss *= self.weights[i]
total_loss += dice_loss
return total_loss/label.shape[1]
class SampledSoftmaxLoss(_Loss): class SampledSoftmaxLoss(_Loss):
r""" r"""
Computes the sampled softmax training loss. Computes the sampled softmax training loss.

View File

@ -15,8 +15,8 @@
""" test loss """ """ test loss """
import numpy as np import numpy as np
import pytest import pytest
import mindspore.common.dtype as mstype from mindspore.common import dtype as mstype
import mindspore.nn as nn from mindspore import nn
from mindspore import Tensor from mindspore import Tensor
from ..ut_filter import non_graph_engine from ..ut_filter import non_graph_engine
@ -107,3 +107,56 @@ def test_dice_loss_check_shape():
y = Tensor(np.array([[1, 0], [0, 1]]), mstype.float32) y = Tensor(np.array([[1, 0], [0, 1]]), mstype.float32)
with pytest.raises(ValueError): with pytest.raises(ValueError):
loss(y_pred, y) loss(y_pred, y)
def test_multi_class_dice_loss():
""" test_multi_class_dice_loss """
loss = nn.MultiClassDiceLoss(weights=None, ignore_indiex=None, activation="softmax")
y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
loss(y_pred, y)
def test_multi_class_dice_loss_check_shape():
""" test_multi_class_dice_loss """
loss = nn.MultiClassDiceLoss(weights=None, ignore_indiex=None, activation="softmax")
y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
y = Tensor(np.array([[1, 0], [0, 1]]), mstype.float32)
with pytest.raises(ValueError):
loss(y_pred, y)
def test_multi_class_dice_loss_init_weight():
""" test_multi_class_dice_loss """
with pytest.raises(TypeError):
loss = nn.MultiClassDiceLoss(weights='1', ignore_indiex=None, activation="softmax")
y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
y = Tensor(np.array([[1, 0], [0, 1]]), mstype.float32)
loss(y_pred, y)
def test_multi_class_dice_loss_init_ignore_indiex():
""" test_multi_class_dice_loss """
with pytest.raises(TypeError):
loss = nn.MultiClassDiceLoss(weights=None, ignore_indiex="2", activation="softmax")
y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
y = Tensor(np.array([[1, 0], [0, 1]]), mstype.float32)
loss(y_pred, y)
def test_multi_class_dice_loss_init_activation():
""" test_multi_class_dice_loss """
with pytest.raises(TypeError):
loss = nn.MultiClassDiceLoss(weights=None, ignore_indiex=None, activation=2)
y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
y = Tensor(np.array([[1, 0], [0, 1]]), mstype.float32)
loss(y_pred, y)
def test_multi_class_dice_loss_init_activation2():
""" test_multi_class_dice_loss """
with pytest.raises(KeyError):
loss = nn.MultiClassDiceLoss(weights=None, ignore_indiex=None, activation='www')
y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
y = Tensor(np.array([[1, 0], [0, 1]]), mstype.float32)
loss(y_pred, y)