diff --git a/mindspore/nn/loss/__init__.py b/mindspore/nn/loss/__init__.py index 873aa0d7f66..fe3b9d983f5 100644 --- a/mindspore/nn/loss/__init__.py +++ b/mindspore/nn/loss/__init__.py @@ -21,8 +21,8 @@ It shows how well the model works on a dataset and the optimization target which from .loss import L1Loss, MSELoss, SmoothL1Loss, \ SoftmaxCrossEntropyWithLogits, BCELoss, CosineEmbeddingLoss, \ - SampledSoftmaxLoss + SampledSoftmaxLoss, DiceLoss __all__ = ['L1Loss', 'MSELoss', 'SmoothL1Loss', 'SoftmaxCrossEntropyWithLogits', 'BCELoss', - 'CosineEmbeddingLoss', 'SampledSoftmaxLoss'] + 'CosineEmbeddingLoss', 'SampledSoftmaxLoss', 'DiceLoss'] diff --git a/mindspore/nn/loss/loss.py b/mindspore/nn/loss/loss.py index f1c0c0a7401..03406f6cc43 100644 --- a/mindspore/nn/loss/loss.py +++ b/mindspore/nn/loss/loss.py @@ -297,6 +297,67 @@ def _check_label_dtype(labels_dtype, cls_name): validator.check_type_name("labels", labels_dtype, [mstype.int32, mstype.int64], cls_name) +class DiceLoss(_Loss): + r""" + The Dice coefficient is a set similarity loss. It is used to calculate the similarity between two samples. The + value of the Dice coefficient is 1 when the segmentation result is the best and 0 when the segmentation result + is the worst. The Dice coefficient indicates the ratio of the area between two objects to the total area. + The function is shown as follows: + + .. math:: + dice = 1 - \frac{2 * (pred \bigcap true)}{pred \bigcup true} + + Args: + smooth (float): A term added to the denominator to improve numerical stability. Should be greater than 0. + Default: 1e-5. + threshold (float): A threshold, which is used to compare with the input tensor. Default: 0.5. + + Inputs: + - **y_pred** (Tensor) - Tensor of shape (N, C). + - **y** (Tensor) - Tensor of shape (N, C). + + Outputs: + Tensor, a tensor of shape with the per-example sampled Dice losses. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> loss = nn.Diceloss(smooth=1e-5, threshold=0.5) + >>> y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32) + >>> y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32) + >>> output = loss(y_pred, y) + >>> print(output) + [0.77777076] + """ + def __init__(self, smooth=1e-5, threshold=0.5): + super(DiceLoss, self).__init__() + self.smooth = validator.check_positive_float(smooth, "smooth") + self.threshold = validator.check_value_type("threshold", threshold, [float]) + self.reshape = P.Reshape() + + def construct(self, logits, label): + _check_shape(logits.shape, label.shape) + logits = self.cast((logits > self.threshold), mstype.float32) + label = self.cast(label, mstype.float32) + dim = label.shape + pred_flat = self.reshape(logits, (dim[0], -1)) + true_flat = self.reshape(label, (dim[0], -1)) + + intersection = self.reduce_sum((pred_flat * true_flat), 1) + unionset = self.reduce_sum(pred_flat, 1) + self.reduce_sum(true_flat, 1) + + dice = (2 * intersection + self.smooth) / (unionset + self.smooth) + dice_loss = 1 - self.reduce_sum(dice) / dim[0] + + return dice_loss + + +@constexpr +def _check_shape(logits_shape, label_shape): + validator.check('logits_shape', logits_shape, 'label_shape', label_shape) + + class SampledSoftmaxLoss(_Loss): r""" Computes the sampled softmax training loss. diff --git a/mindspore/nn/metrics/dice.py b/mindspore/nn/metrics/dice.py index a4110c1e7ac..56524c7d0fd 100644 --- a/mindspore/nn/metrics/dice.py +++ b/mindspore/nn/metrics/dice.py @@ -26,7 +26,7 @@ class Dice(Metric): The function is shown as follows: .. math:: - \text{dice} = \frac{2 * (\text{pred} \bigcap \text{true})}{\text{pred} \bigcup \text{true}} + dice = \frac{2 * (pred \bigcap true)}{pred \bigcup true} Args: smooth (float): A term added to the denominator to improve numerical stability. Should be greater than 0. @@ -58,7 +58,7 @@ class Dice(Metric): def update(self, *inputs): """ - Updates the internal evaluation result :math:`y_{pred}` and :math:`y`. + Updates the internal evaluation result :math:`y_pred` and :math:`y`. Args: inputs: Input `y_pred` and `y`. `y_pred` and `y` are Tensor, list or numpy.ndarray. `y_pred` is the diff --git a/mindspore/nn/metrics/hausdorff_distance.py b/mindspore/nn/metrics/hausdorff_distance.py index 9731a0b3ff4..d9354870111 100644 --- a/mindspore/nn/metrics/hausdorff_distance.py +++ b/mindspore/nn/metrics/hausdorff_distance.py @@ -70,9 +70,9 @@ class HausdorffDistance(Metric): Given two feature sets A and B, the Hausdorff distance between two point sets A and B is defined as follows: .. math:: - \text{H}(A, B) = \text{max}[\text{h}(A, B), \text{h}(B, A)] - \text{h}(A, B) = \underset{a \in A}{\text{max}}\{\underset{b \in B}{\text{min}} \rVert a - b \rVert \} - \text{h}(A, B) = \underset{b \in B}{\text{max}}\{\underset{a \in A}{\text{min}} \rVert b - a \rVert \} + H(A, B) = \text{max}[h(A, B), h(B, A)] + h(A, B) = \underset{a \in A}{\text{max}}\{\underset{b \in B}{\text{min}} \rVert a - b \rVert \} + h(A, B) = \underset{b \in B}{\text{max}}\{\underset{a \in A}{\text{min}} \rVert b - a \rVert \} Args: distance_metric (string): The parameter of calculating Hausdorff distance supports three measurement methods, diff --git a/model_zoo/official/cv/unet/eval.py b/model_zoo/official/cv/unet/eval.py index c4d1373676d..6b8e746c6ea 100644 --- a/model_zoo/official/cv/unet/eval.py +++ b/model_zoo/official/cv/unet/eval.py @@ -85,6 +85,7 @@ class dice_coeff(nn.Metric): raise RuntimeError('Total samples num must not be 0.') return self._dice_coeff_sum / float(self._samples_num) + def test_net(data_dir, ckpt_path, cross_valid_ind=1, @@ -102,6 +103,7 @@ def test_net(data_dir, dice_score = model.eval(valid_dataset, dataset_sink_mode=False) print("============== Cross valid dice coeff is:", dice_score) + def get_args(): parser = argparse.ArgumentParser(description='Test the UNet on images and target masks', formatter_class=argparse.ArgumentDefaultsHelpFormatter) diff --git a/tests/ut/python/nn/test_loss.py b/tests/ut/python/nn/test_loss.py index f4d97ef1acc..a7c6e524220 100644 --- a/tests/ut/python/nn/test_loss.py +++ b/tests/ut/python/nn/test_loss.py @@ -14,7 +14,8 @@ # ============================================================================ """ test loss """ import numpy as np - +import pytest +import mindspore.common.dtype as mstype import mindspore.nn as nn from mindspore import Tensor from ..ut_filter import non_graph_engine @@ -88,3 +89,22 @@ def test_cosine_embedding_loss(): x2 = Tensor(np.array([[0.4, 1.2], [-0.4, -0.9]]).astype(np.float32)) label = Tensor(np.array([1, -1]).astype(np.int32)) loss(x1, x2, label) + + +def test_dice_loss(): + """ test_dice_loss """ + loss = nn.DiceLoss() + y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32) + y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32) + # Pass the test if no error is reported + loss(y_pred, y).asnumpy() + + + +def test_dice_loss_check_shape(): + """ test_dice_loss """ + loss = nn.DiceLoss() + y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32) + y = Tensor(np.array([[1, 0], [0, 1]]), mstype.float32) + with pytest.raises(ValueError): + loss(y_pred, y)