From 952e24500046e6cc172a02ad0e47e436759d369c Mon Sep 17 00:00:00 2001
From: Jiaqi
Date: Mon, 8 Feb 2021 10:31:02 +0800
Subject: [PATCH] add multiclass diceloss

---
 mindspore/nn/loss/__init__.py   |  5 +-
 mindspore/nn/loss/loss.py       | 82 +++++++++++++++++++++++++++++++--
 tests/ut/python/nn/test_loss.py | 57 ++++++++++++++++++++++-
 3 files changed, 136 insertions(+), 8 deletions(-)

diff --git a/mindspore/nn/loss/__init__.py b/mindspore/nn/loss/__init__.py
index 8f1530de588..c64f96c3018 100644
--- a/mindspore/nn/loss/__init__.py
+++ b/mindspore/nn/loss/__init__.py
@@ -21,8 +21,9 @@ It shows how well the model works on a dataset and the optimization target which
 from .loss import L1Loss, MSELoss, SmoothL1Loss, \
     SoftmaxCrossEntropyWithLogits, BCELoss, CosineEmbeddingLoss, \
-    SampledSoftmaxLoss, DiceLoss, BCEWithLogitsLoss
+    SampledSoftmaxLoss, DiceLoss, BCEWithLogitsLoss, MultiClassDiceLoss
+
 __all__ = ['L1Loss', 'MSELoss', 'SmoothL1Loss',
            'SoftmaxCrossEntropyWithLogits', 'BCELoss', 'BCEWithLogitsLoss',
-           'CosineEmbeddingLoss', 'SampledSoftmaxLoss', 'DiceLoss']
+           'CosineEmbeddingLoss', 'SampledSoftmaxLoss', 'DiceLoss', 'MultiClassDiceLoss']

diff --git a/mindspore/nn/loss/loss.py b/mindspore/nn/loss/loss.py
index 7b8c19c2015..9d83f2728a5 100644
--- a/mindspore/nn/loss/loss.py
+++ b/mindspore/nn/loss/loss.py
@@ -21,6 +21,7 @@ from mindspore.ops import functional as F
 from mindspore.ops.primitive import constexpr
 from mindspore.ops import _selected_ops
 from mindspore.nn.cell import Cell
+from mindspore.nn.layer.activation import get_activation
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
 from ... import context
@@ -329,14 +330,14 @@ class DiceLoss(_Loss):
             Default: 1e-5.
 
     Inputs:
-        - **y_pred** (Tensor) - Tensor of shape (N, ...).
-        - **y** (Tensor) - Tensor of shape (N, ...).
+        - **y_pred** (Tensor) - Tensor of shape (N, ...). The data type must be float16 or float32.
+        - **y** (Tensor) - Tensor of shape (N, ...). The data type must be float16 or float32.
 
     Outputs:
         Tensor, the computed Dice loss.
 
     Supported Platforms:
-        ``Ascend``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> loss = nn.DiceLoss(smooth=1e-5)
@@ -364,7 +365,7 @@ class DiceLoss(_Loss):
         single_dice_coeff = (2 * intersection) / (unionset + self.smooth)
         dice_loss = 1 - single_dice_coeff / label.shape[0]
 
-        return dice_loss
+        return dice_loss.mean()
 
 
 @constexpr
@@ -372,6 +373,79 @@ def _check_shape(logits_shape, label_shape):
     validator.check('logits_shape', logits_shape, 'label_shape', label_shape)
 
 
+@constexpr
+def _check_weights(weight, label):
+    if weight.shape[0] != label.shape[1]:
+        raise ValueError("The first dimension of weight should be equal to the number of classes "
+                         "(label.shape[1]), but the shape of weight is {} and the shape of label is {}."
+                         .format(weight.shape, label.shape))
+
+
+class MultiClassDiceLoss(_Loss):
+    r"""
+    When there are multiple classes, the label is converted into multiple binary labels by one-hot
+    encoding. Each channel of the class dimension is then treated as an independent binary
+    classification problem: a binary Dice loss is computed for every class, and the final loss is
+    the average over all classes.
+
+    Args:
+        weights (Union[Tensor, None]): Tensor of shape `[num_classes, dim]` used to weight the
+            per-class losses. Its first dimension must be equal to the number of classes.
+            Default: None.
+        ignore_index (Union[int, None]): Index of a class to exclude from the loss. Default: None.
+        activation (Union[str, Cell]): Activation function applied to `y_pred` before the loss is
+            computed, e.g. 'relu'. Default: 'softmax'.
+            Choose from: ['softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'fast_gelu',
+            'sigmoid', 'prelu', 'leakyrelu', 'hsigmoid', 'hswish', 'elu', 'logsigmoid'].
+
+    Inputs:
+        - **y_pred** (Tensor) - Tensor of shape (N, C, ...). The data type must be float16 or float32.
+        - **y** (Tensor) - Tensor of shape (N, C, ...). The data type must be float16 or float32.
+
+    Outputs:
+        Tensor, the multi-class Dice loss averaged over the class dimension.
+
+    Raises:
+        ValueError: If the shape of `y_pred` is different from that of `y`, or if the first
+            dimension of `weights` is not equal to the number of classes.
+        TypeError: If `weights` is not a Tensor, if `ignore_index` is not an int, or if
+            `activation` is neither a str nor a Cell.
+        KeyError: If `activation` is a str but not one of the supported activation names.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> loss = nn.MultiClassDiceLoss(weights=None, ignore_index=None, activation="softmax")
+        >>> y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
+        >>> y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
+        >>> output = loss(y_pred, y)
+        >>> print(output)
+        [0.7761003]
+    """
+    def __init__(self, weights=None, ignore_index=None, activation="softmax"):
+        super(MultiClassDiceLoss, self).__init__()
+        self.binarydiceloss = DiceLoss(smooth=1e-5)
+        self.weights = weights if weights is None else validator.check_value_type("weights", weights, [Tensor])
+        self.ignore_index = ignore_index if ignore_index is None else \
+            validator.check_value_type("ignore_index", ignore_index, [int])
+        self.activation = get_activation(activation) if isinstance(activation, str) else activation
+        if self.activation is not None and not isinstance(self.activation, Cell):
+            raise TypeError("The activation must be str or Cell, but got {}.".format(activation))
+        self.reshape = P.Reshape()
+
+    def construct(self, logits, label):
+        _check_shape(logits.shape, label.shape)
+        total_loss = 0
+
+        if self.activation is not None:
+            logits = self.activation(logits)
+
+        # Treat every channel of the class dimension as a binary segmentation
+        # problem and accumulate the (optionally weighted) binary Dice losses.
+        for i in range(label.shape[1]):
+            if i != self.ignore_index:
+                dice_loss = self.binarydiceloss(logits[:, i], label[:, i])
+                if self.weights is not None:
+                    _check_weights(self.weights, label)
+                    dice_loss *= self.weights[i]
+                total_loss += dice_loss
+
+        # Note: the average is taken over all classes, including an ignored one.
+        return total_loss / label.shape[1]
+
+
 class SampledSoftmaxLoss(_Loss):
     r"""
     Computes the sampled softmax training loss.
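For intuition, the per-class decomposition performed by `MultiClassDiceLoss.construct` above can be
reproduced with a few lines of NumPy. This is a minimal sketch, not part of the patch: the helper
names `binary_dice` and `multiclass_dice` are made up for illustration, and the binary term uses the
textbook form 1 - 2*|X∩Y| / (|X|^2 + |Y|^2), whereas the upstream `DiceLoss` additionally divides
the coefficient by the batch size, so absolute values will differ.

    import numpy as np

    def binary_dice(pred, target, smooth=1e-5):
        # Textbook binary Dice loss over flattened inputs.
        pred, target = pred.ravel(), target.ravel()
        intersection = np.sum(pred * target)
        unionset = np.sum(pred * pred) + np.sum(target * target)
        return 1.0 - (2.0 * intersection) / (unionset + smooth)

    def multiclass_dice(probs, label, weights=None, ignore_index=None):
        # Average the per-class binary Dice losses over the class axis (axis 1),
        # mirroring the loop in MultiClassDiceLoss.construct.
        total = 0.0
        for i in range(label.shape[1]):
            if i == ignore_index:
                continue
            loss_i = binary_dice(probs[:, i], label[:, i])
            if weights is not None:
                loss_i *= weights[i]
            total += loss_i
        return total / label.shape[1]

As in the patch, the divisor is the full `label.shape[1]` even when `ignore_index` skips a class,
so skipping a class scales the result down rather than renormalizing over the remaining classes.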
diff --git a/tests/ut/python/nn/test_loss.py b/tests/ut/python/nn/test_loss.py
index 0408aabff83..c19f79c1640 100644
--- a/tests/ut/python/nn/test_loss.py
+++ b/tests/ut/python/nn/test_loss.py
@@ -15,8 +15,8 @@
 """ test loss """
 import numpy as np
 import pytest
-import mindspore.common.dtype as mstype
-import mindspore.nn as nn
+from mindspore.common import dtype as mstype
+from mindspore import nn
 from mindspore import Tensor
 from ..ut_filter import non_graph_engine
@@ -107,3 +107,56 @@ def test_dice_loss_check_shape():
     y = Tensor(np.array([[1, 0], [0, 1]]), mstype.float32)
     with pytest.raises(ValueError):
         loss(y_pred, y)
+
+
+def test_multi_class_dice_loss():
+    """ test_multi_class_dice_loss """
+    loss = nn.MultiClassDiceLoss(weights=None, ignore_index=None, activation="softmax")
+    y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
+    y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
+    loss(y_pred, y)
+
+
+def test_multi_class_dice_loss_check_shape():
+    """ test_multi_class_dice_loss_check_shape """
+    loss = nn.MultiClassDiceLoss(weights=None, ignore_index=None, activation="softmax")
+    y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
+    y = Tensor(np.array([[1, 0], [0, 1]]), mstype.float32)
+    with pytest.raises(ValueError):
+        loss(y_pred, y)
+
+
+def test_multi_class_dice_loss_init_weight():
+    """ test_multi_class_dice_loss_init_weight """
+    with pytest.raises(TypeError):
+        loss = nn.MultiClassDiceLoss(weights='1', ignore_index=None, activation="softmax")
+        y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
+        y = Tensor(np.array([[1, 0], [0, 1]]), mstype.float32)
+        loss(y_pred, y)
+
+
+def test_multi_class_dice_loss_init_ignore_index():
+    """ test_multi_class_dice_loss_init_ignore_index """
+    with pytest.raises(TypeError):
+        loss = nn.MultiClassDiceLoss(weights=None, ignore_index="2", activation="softmax")
+        y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
+        y = Tensor(np.array([[1, 0], [0, 1]]), mstype.float32)
+        loss(y_pred, y)
+
+
+def test_multi_class_dice_loss_init_activation():
+    """ test_multi_class_dice_loss_init_activation """
+    with pytest.raises(TypeError):
+        loss = nn.MultiClassDiceLoss(weights=None, ignore_index=None, activation=2)
+        y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
+        y = Tensor(np.array([[1, 0], [0, 1]]), mstype.float32)
+        loss(y_pred, y)
+
+
+def test_multi_class_dice_loss_init_activation2():
+    """ test_multi_class_dice_loss_init_activation2 """
+    with pytest.raises(KeyError):
+        loss = nn.MultiClassDiceLoss(weights=None, ignore_index=None, activation='www')
+        y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
+        y = Tensor(np.array([[1, 0], [0, 1]]), mstype.float32)
+        loss(y_pred, y)
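To see how the two optional arguments interact once the patch is applied, here is a small usage
sketch. The weight values are illustrative only; the one real constraint, enforced by
`_check_weights`, is that `weights.shape[0]` equals the number of classes (`label.shape[1]`).

    import numpy as np
    from mindspore import Tensor, nn
    from mindspore.common import dtype as mstype

    # One weight row per class; shape[0] must equal the number of classes.
    weights = Tensor(np.array([[0.5], [1.0]]), mstype.float32)  # illustrative values
    loss = nn.MultiClassDiceLoss(weights=weights, ignore_index=0, activation="softmax")

    y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
    y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)

    # Class 0 is skipped via ignore_index; class 1's binary Dice loss is
    # scaled by weights[1] before the average over both classes.
    output = loss(y_pred, y)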