forked from mindspore-Ecosystem/mindspore
!11195 develop dice loss
From: @lijiaqi0612 Reviewed-by: Signed-off-by:
This commit is contained in:
commit
e214c69bc2
|
@ -21,8 +21,8 @@ It shows how well the model works on a dataset and the optimization target which
|
||||||
|
|
||||||
from .loss import L1Loss, MSELoss, SmoothL1Loss, \
|
from .loss import L1Loss, MSELoss, SmoothL1Loss, \
|
||||||
SoftmaxCrossEntropyWithLogits, BCELoss, CosineEmbeddingLoss, \
|
SoftmaxCrossEntropyWithLogits, BCELoss, CosineEmbeddingLoss, \
|
||||||
SampledSoftmaxLoss
|
SampledSoftmaxLoss, DiceLoss
|
||||||
|
|
||||||
__all__ = ['L1Loss', 'MSELoss', 'SmoothL1Loss',
|
__all__ = ['L1Loss', 'MSELoss', 'SmoothL1Loss',
|
||||||
'SoftmaxCrossEntropyWithLogits', 'BCELoss',
|
'SoftmaxCrossEntropyWithLogits', 'BCELoss',
|
||||||
'CosineEmbeddingLoss', 'SampledSoftmaxLoss']
|
'CosineEmbeddingLoss', 'SampledSoftmaxLoss', 'DiceLoss']
|
||||||
|
|
|
@ -297,6 +297,67 @@ def _check_label_dtype(labels_dtype, cls_name):
|
||||||
validator.check_type_name("labels", labels_dtype, [mstype.int32, mstype.int64], cls_name)
|
validator.check_type_name("labels", labels_dtype, [mstype.int32, mstype.int64], cls_name)
|
||||||
|
|
||||||
|
|
||||||
|
class DiceLoss(_Loss):
|
||||||
|
r"""
|
||||||
|
The Dice coefficient is a set similarity loss. It is used to calculate the similarity between two samples. The
|
||||||
|
value of the Dice coefficient is 1 when the segmentation result is the best and 0 when the segmentation result
|
||||||
|
is the worst. The Dice coefficient indicates the ratio of the area between two objects to the total area.
|
||||||
|
The function is shown as follows:
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
dice = 1 - \frac{2 * (pred \bigcap true)}{pred \bigcup true}
|
||||||
|
|
||||||
|
Args:
|
||||||
|
smooth (float): A term added to the denominator to improve numerical stability. Should be greater than 0.
|
||||||
|
Default: 1e-5.
|
||||||
|
threshold (float): A threshold, which is used to compare with the input tensor. Default: 0.5.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **y_pred** (Tensor) - Tensor of shape (N, C).
|
||||||
|
- **y** (Tensor) - Tensor of shape (N, C).
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, a tensor of shape with the per-example sampled Dice losses.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend``
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> loss = nn.Diceloss(smooth=1e-5, threshold=0.5)
|
||||||
|
>>> y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
|
||||||
|
>>> y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
|
||||||
|
>>> output = loss(y_pred, y)
|
||||||
|
>>> print(output)
|
||||||
|
[0.77777076]
|
||||||
|
"""
|
||||||
|
def __init__(self, smooth=1e-5, threshold=0.5):
|
||||||
|
super(DiceLoss, self).__init__()
|
||||||
|
self.smooth = validator.check_positive_float(smooth, "smooth")
|
||||||
|
self.threshold = validator.check_value_type("threshold", threshold, [float])
|
||||||
|
self.reshape = P.Reshape()
|
||||||
|
|
||||||
|
def construct(self, logits, label):
|
||||||
|
_check_shape(logits.shape, label.shape)
|
||||||
|
logits = self.cast((logits > self.threshold), mstype.float32)
|
||||||
|
label = self.cast(label, mstype.float32)
|
||||||
|
dim = label.shape
|
||||||
|
pred_flat = self.reshape(logits, (dim[0], -1))
|
||||||
|
true_flat = self.reshape(label, (dim[0], -1))
|
||||||
|
|
||||||
|
intersection = self.reduce_sum((pred_flat * true_flat), 1)
|
||||||
|
unionset = self.reduce_sum(pred_flat, 1) + self.reduce_sum(true_flat, 1)
|
||||||
|
|
||||||
|
dice = (2 * intersection + self.smooth) / (unionset + self.smooth)
|
||||||
|
dice_loss = 1 - self.reduce_sum(dice) / dim[0]
|
||||||
|
|
||||||
|
return dice_loss
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def _check_shape(logits_shape, label_shape):
|
||||||
|
validator.check('logits_shape', logits_shape, 'label_shape', label_shape)
|
||||||
|
|
||||||
|
|
||||||
class SampledSoftmaxLoss(_Loss):
|
class SampledSoftmaxLoss(_Loss):
|
||||||
r"""
|
r"""
|
||||||
Computes the sampled softmax training loss.
|
Computes the sampled softmax training loss.
|
||||||
|
|
|
@ -26,7 +26,7 @@ class Dice(Metric):
|
||||||
The function is shown as follows:
|
The function is shown as follows:
|
||||||
|
|
||||||
.. math::
|
.. math::
|
||||||
\text{dice} = \frac{2 * (\text{pred} \bigcap \text{true})}{\text{pred} \bigcup \text{true}}
|
dice = \frac{2 * (pred \bigcap true)}{pred \bigcup true}
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
smooth (float): A term added to the denominator to improve numerical stability. Should be greater than 0.
|
smooth (float): A term added to the denominator to improve numerical stability. Should be greater than 0.
|
||||||
|
@ -58,7 +58,7 @@ class Dice(Metric):
|
||||||
|
|
||||||
def update(self, *inputs):
|
def update(self, *inputs):
|
||||||
"""
|
"""
|
||||||
Updates the internal evaluation result :math:`y_{pred}` and :math:`y`.
|
Updates the internal evaluation result :math:`y_pred` and :math:`y`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
inputs: Input `y_pred` and `y`. `y_pred` and `y` are Tensor, list or numpy.ndarray. `y_pred` is the
|
inputs: Input `y_pred` and `y`. `y_pred` and `y` are Tensor, list or numpy.ndarray. `y_pred` is the
|
||||||
|
|
|
@ -70,9 +70,9 @@ class HausdorffDistance(Metric):
|
||||||
Given two feature sets A and B, the Hausdorff distance between two point sets A and B is defined as follows:
|
Given two feature sets A and B, the Hausdorff distance between two point sets A and B is defined as follows:
|
||||||
|
|
||||||
.. math::
|
.. math::
|
||||||
\text{H}(A, B) = \text{max}[\text{h}(A, B), \text{h}(B, A)]
|
H(A, B) = \text{max}[h(A, B), h(B, A)]
|
||||||
\text{h}(A, B) = \underset{a \in A}{\text{max}}\{\underset{b \in B}{\text{min}} \rVert a - b \rVert \}
|
h(A, B) = \underset{a \in A}{\text{max}}\{\underset{b \in B}{\text{min}} \rVert a - b \rVert \}
|
||||||
\text{h}(A, B) = \underset{b \in B}{\text{max}}\{\underset{a \in A}{\text{min}} \rVert b - a \rVert \}
|
h(A, B) = \underset{b \in B}{\text{max}}\{\underset{a \in A}{\text{min}} \rVert b - a \rVert \}
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
distance_metric (string): The parameter of calculating Hausdorff distance supports three measurement methods,
|
distance_metric (string): The parameter of calculating Hausdorff distance supports three measurement methods,
|
||||||
|
|
|
@ -85,6 +85,7 @@ class dice_coeff(nn.Metric):
|
||||||
raise RuntimeError('Total samples num must not be 0.')
|
raise RuntimeError('Total samples num must not be 0.')
|
||||||
return self._dice_coeff_sum / float(self._samples_num)
|
return self._dice_coeff_sum / float(self._samples_num)
|
||||||
|
|
||||||
|
|
||||||
def test_net(data_dir,
|
def test_net(data_dir,
|
||||||
ckpt_path,
|
ckpt_path,
|
||||||
cross_valid_ind=1,
|
cross_valid_ind=1,
|
||||||
|
@ -102,6 +103,7 @@ def test_net(data_dir,
|
||||||
dice_score = model.eval(valid_dataset, dataset_sink_mode=False)
|
dice_score = model.eval(valid_dataset, dataset_sink_mode=False)
|
||||||
print("============== Cross valid dice coeff is:", dice_score)
|
print("============== Cross valid dice coeff is:", dice_score)
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser(description='Test the UNet on images and target masks',
|
parser = argparse.ArgumentParser(description='Test the UNet on images and target masks',
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||||
|
|
|
@ -14,7 +14,8 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
""" test loss """
|
""" test loss """
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from ..ut_filter import non_graph_engine
|
from ..ut_filter import non_graph_engine
|
||||||
|
@ -88,3 +89,22 @@ def test_cosine_embedding_loss():
|
||||||
x2 = Tensor(np.array([[0.4, 1.2], [-0.4, -0.9]]).astype(np.float32))
|
x2 = Tensor(np.array([[0.4, 1.2], [-0.4, -0.9]]).astype(np.float32))
|
||||||
label = Tensor(np.array([1, -1]).astype(np.int32))
|
label = Tensor(np.array([1, -1]).astype(np.int32))
|
||||||
loss(x1, x2, label)
|
loss(x1, x2, label)
|
||||||
|
|
||||||
|
|
||||||
|
def test_dice_loss():
|
||||||
|
""" test_dice_loss """
|
||||||
|
loss = nn.DiceLoss()
|
||||||
|
y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
|
||||||
|
y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
|
||||||
|
# Pass the test if no error is reported
|
||||||
|
loss(y_pred, y).asnumpy()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_dice_loss_check_shape():
|
||||||
|
""" test_dice_loss """
|
||||||
|
loss = nn.DiceLoss()
|
||||||
|
y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
|
||||||
|
y = Tensor(np.array([[1, 0], [0, 1]]), mstype.float32)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
loss(y_pred, y)
|
||||||
|
|
Loading…
Reference in New Issue