!11195 develop dice loss

From: @lijiaqi0612
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-01-18 12:00:41 +08:00 committed by Gitee
commit e214c69bc2
6 changed files with 91 additions and 8 deletions

View File

@ -21,8 +21,8 @@ It shows how well the model works on a dataset and the optimization target which
from .loss import L1Loss, MSELoss, SmoothL1Loss, \ from .loss import L1Loss, MSELoss, SmoothL1Loss, \
SoftmaxCrossEntropyWithLogits, BCELoss, CosineEmbeddingLoss, \ SoftmaxCrossEntropyWithLogits, BCELoss, CosineEmbeddingLoss, \
SampledSoftmaxLoss SampledSoftmaxLoss, DiceLoss
__all__ = ['L1Loss', 'MSELoss', 'SmoothL1Loss', __all__ = ['L1Loss', 'MSELoss', 'SmoothL1Loss',
'SoftmaxCrossEntropyWithLogits', 'BCELoss', 'SoftmaxCrossEntropyWithLogits', 'BCELoss',
'CosineEmbeddingLoss', 'SampledSoftmaxLoss'] 'CosineEmbeddingLoss', 'SampledSoftmaxLoss', 'DiceLoss']

View File

@ -297,6 +297,67 @@ def _check_label_dtype(labels_dtype, cls_name):
validator.check_type_name("labels", labels_dtype, [mstype.int32, mstype.int64], cls_name) validator.check_type_name("labels", labels_dtype, [mstype.int32, mstype.int64], cls_name)
class DiceLoss(_Loss):
r"""
The Dice coefficient is a set similarity loss. It is used to calculate the similarity between two samples. The
value of the Dice coefficient is 1 when the segmentation result is the best and 0 when the segmentation result
is the worst. The Dice coefficient indicates the ratio of the area between two objects to the total area.
The function is shown as follows:
.. math::
dice = 1 - \frac{2 * (pred \bigcap true)}{pred \bigcup true}
Args:
smooth (float): A term added to the denominator to improve numerical stability. Should be greater than 0.
Default: 1e-5.
threshold (float): A threshold, which is used to compare with the input tensor. Default: 0.5.
Inputs:
- **y_pred** (Tensor) - Tensor of shape (N, C).
- **y** (Tensor) - Tensor of shape (N, C).
Outputs:
Tensor, a tensor of shape with the per-example sampled Dice losses.
Supported Platforms:
``Ascend``
Examples:
>>> loss = nn.Diceloss(smooth=1e-5, threshold=0.5)
>>> y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
>>> y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
>>> output = loss(y_pred, y)
>>> print(output)
[0.77777076]
"""
def __init__(self, smooth=1e-5, threshold=0.5):
super(DiceLoss, self).__init__()
self.smooth = validator.check_positive_float(smooth, "smooth")
self.threshold = validator.check_value_type("threshold", threshold, [float])
self.reshape = P.Reshape()
def construct(self, logits, label):
_check_shape(logits.shape, label.shape)
logits = self.cast((logits > self.threshold), mstype.float32)
label = self.cast(label, mstype.float32)
dim = label.shape
pred_flat = self.reshape(logits, (dim[0], -1))
true_flat = self.reshape(label, (dim[0], -1))
intersection = self.reduce_sum((pred_flat * true_flat), 1)
unionset = self.reduce_sum(pred_flat, 1) + self.reduce_sum(true_flat, 1)
dice = (2 * intersection + self.smooth) / (unionset + self.smooth)
dice_loss = 1 - self.reduce_sum(dice) / dim[0]
return dice_loss
@constexpr
def _check_shape(logits_shape, label_shape):
validator.check('logits_shape', logits_shape, 'label_shape', label_shape)
class SampledSoftmaxLoss(_Loss): class SampledSoftmaxLoss(_Loss):
r""" r"""
Computes the sampled softmax training loss. Computes the sampled softmax training loss.

View File

@ -26,7 +26,7 @@ class Dice(Metric):
The function is shown as follows: The function is shown as follows:
.. math:: .. math::
\text{dice} = \frac{2 * (\text{pred} \bigcap \text{true})}{\text{pred} \bigcup \text{true}} dice = \frac{2 * (pred \bigcap true)}{pred \bigcup true}
Args: Args:
smooth (float): A term added to the denominator to improve numerical stability. Should be greater than 0. smooth (float): A term added to the denominator to improve numerical stability. Should be greater than 0.
@ -58,7 +58,7 @@ class Dice(Metric):
def update(self, *inputs): def update(self, *inputs):
""" """
Updates the internal evaluation result :math:`y_{pred}` and :math:`y`. Updates the internal evaluation result :math:`y_pred` and :math:`y`.
Args: Args:
inputs: Input `y_pred` and `y`. `y_pred` and `y` are Tensor, list or numpy.ndarray. `y_pred` is the inputs: Input `y_pred` and `y`. `y_pred` and `y` are Tensor, list or numpy.ndarray. `y_pred` is the

View File

@ -70,9 +70,9 @@ class HausdorffDistance(Metric):
Given two feature sets A and B, the Hausdorff distance between two point sets A and B is defined as follows: Given two feature sets A and B, the Hausdorff distance between two point sets A and B is defined as follows:
.. math:: .. math::
\text{H}(A, B) = \text{max}[\text{h}(A, B), \text{h}(B, A)] H(A, B) = \text{max}[h(A, B), h(B, A)]
\text{h}(A, B) = \underset{a \in A}{\text{max}}\{\underset{b \in B}{\text{min}} \rVert a - b \rVert \} h(A, B) = \underset{a \in A}{\text{max}}\{\underset{b \in B}{\text{min}} \rVert a - b \rVert \}
\text{h}(A, B) = \underset{b \in B}{\text{max}}\{\underset{a \in A}{\text{min}} \rVert b - a \rVert \} h(A, B) = \underset{b \in B}{\text{max}}\{\underset{a \in A}{\text{min}} \rVert b - a \rVert \}
Args: Args:
distance_metric (string): The parameter of calculating Hausdorff distance supports three measurement methods, distance_metric (string): The parameter of calculating Hausdorff distance supports three measurement methods,

View File

@ -85,6 +85,7 @@ class dice_coeff(nn.Metric):
raise RuntimeError('Total samples num must not be 0.') raise RuntimeError('Total samples num must not be 0.')
return self._dice_coeff_sum / float(self._samples_num) return self._dice_coeff_sum / float(self._samples_num)
def test_net(data_dir, def test_net(data_dir,
ckpt_path, ckpt_path,
cross_valid_ind=1, cross_valid_ind=1,
@ -102,6 +103,7 @@ def test_net(data_dir,
dice_score = model.eval(valid_dataset, dataset_sink_mode=False) dice_score = model.eval(valid_dataset, dataset_sink_mode=False)
print("============== Cross valid dice coeff is:", dice_score) print("============== Cross valid dice coeff is:", dice_score)
def get_args(): def get_args():
parser = argparse.ArgumentParser(description='Test the UNet on images and target masks', parser = argparse.ArgumentParser(description='Test the UNet on images and target masks',
formatter_class=argparse.ArgumentDefaultsHelpFormatter) formatter_class=argparse.ArgumentDefaultsHelpFormatter)

View File

@ -14,7 +14,8 @@
# ============================================================================ # ============================================================================
""" test loss """ """ test loss """
import numpy as np import numpy as np
import pytest
import mindspore.common.dtype as mstype
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor from mindspore import Tensor
from ..ut_filter import non_graph_engine from ..ut_filter import non_graph_engine
@ -88,3 +89,22 @@ def test_cosine_embedding_loss():
x2 = Tensor(np.array([[0.4, 1.2], [-0.4, -0.9]]).astype(np.float32)) x2 = Tensor(np.array([[0.4, 1.2], [-0.4, -0.9]]).astype(np.float32))
label = Tensor(np.array([1, -1]).astype(np.int32)) label = Tensor(np.array([1, -1]).astype(np.int32))
loss(x1, x2, label) loss(x1, x2, label)
def test_dice_loss():
""" test_dice_loss """
loss = nn.DiceLoss()
y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
# Pass the test if no error is reported
loss(y_pred, y).asnumpy()
def test_dice_loss_check_shape():
""" test_dice_loss """
loss = nn.DiceLoss()
y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
y = Tensor(np.array([[1, 0], [0, 1]]), mstype.float32)
with pytest.raises(ValueError):
loss(y_pred, y)