forked from mindspore-Ecosystem/mindspore
!11813 modify dice and diceloss
From: @lijiaqi0612 Reviewed-by: @kingxian,@zh_qh Signed-off-by: @kingxian
This commit is contained in:
commit
84afbb510f
|
@ -310,11 +310,10 @@ class DiceLoss(_Loss):
|
|||
Args:
|
||||
smooth (float): A term added to the denominator to improve numerical stability. Should be greater than 0.
|
||||
Default: 1e-5.
|
||||
threshold (float): A threshold, which is used to compare with the input tensor. Default: 0.5.
|
||||
|
||||
Inputs:
|
||||
- **y_pred** (Tensor) - Tensor of shape (N, C).
|
||||
- **y** (Tensor) - Tensor of shape (N, C).
|
||||
- **y_pred** (Tensor) - Tensor of shape (N, ...).
|
||||
- **y** (Tensor) - Tensor of shape (N, ...).
|
||||
|
||||
Outputs:
|
||||
Tensor, a tensor of shape with the per-example sampled Dice losses.
|
||||
|
@ -323,32 +322,30 @@ class DiceLoss(_Loss):
|
|||
``Ascend``
|
||||
|
||||
Examples:
|
||||
>>> loss = nn.Diceloss(smooth=1e-5, threshold=0.5)
|
||||
>>> loss = nn.DiceLoss(smooth=1e-5)
|
||||
>>> y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
|
||||
>>> y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
|
||||
>>> output = loss(y_pred, y)
|
||||
>>> print(output)
|
||||
[0.77777076]
|
||||
[0.7953220862819745]
|
||||
|
||||
Raises:
|
||||
ValueError: If the dimensions are different.
|
||||
TypeError: If the type of inputs are not Tensor.
|
||||
"""
|
||||
def __init__(self, smooth=1e-5, threshold=0.5):
|
||||
def __init__(self, smooth=1e-5):
|
||||
super(DiceLoss, self).__init__()
|
||||
self.smooth = validator.check_positive_float(smooth, "smooth")
|
||||
self.threshold = validator.check_value_type("threshold", threshold, [float])
|
||||
self.reshape = P.Reshape()
|
||||
|
||||
def construct(self, logits, label):
|
||||
_check_shape(logits.shape, label.shape)
|
||||
logits = self.cast((logits > self.threshold), mstype.float32)
|
||||
label = self.cast(label, mstype.float32)
|
||||
dim = label.shape
|
||||
pred_flat = self.reshape(logits, (dim[0], -1))
|
||||
true_flat = self.reshape(label, (dim[0], -1))
|
||||
intersection = self.reduce_sum(self.mul(logits.view(-1), label.view(-1)))
|
||||
unionset = self.reduce_sum(self.mul(logits.view(-1), logits.view(-1))) + \
|
||||
self.reduce_sum(self.mul(label.view(-1), label.view(-1)))
|
||||
|
||||
intersection = self.reduce_sum((pred_flat * true_flat), 1)
|
||||
unionset = self.reduce_sum(pred_flat, 1) + self.reduce_sum(true_flat, 1)
|
||||
|
||||
dice = (2 * intersection + self.smooth) / (unionset + self.smooth)
|
||||
dice_loss = 1 - self.reduce_sum(dice) / dim[0]
|
||||
single_dice_coeff = (2 * intersection) / (unionset + self.smooth)
|
||||
dice_loss = 1 - single_dice_coeff / label.shape[0]
|
||||
|
||||
return dice_loss
|
||||
|
||||
|
|
|
@ -26,35 +26,35 @@ class Dice(Metric):
|
|||
The function is shown as follows:
|
||||
|
||||
.. math::
|
||||
dice = \frac{2 * (pred \bigcap true)}{pred \bigcup true}
|
||||
dice = \frac{2 * (pred \bigcap true)}{pred \bigcup true}
|
||||
|
||||
Args:
|
||||
smooth (float): A term added to the denominator to improve numerical stability. Should be greater than 0.
|
||||
Default: 1e-5.
|
||||
threshold (float): A threshold, which is used to compare with the input tensor. Default: 0.5.
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
|
||||
>>> y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]))
|
||||
>>> metric = Dice(smooth=1e-5, threshold=0.5)
|
||||
>>> metric = Dice(smooth=1e-5)
|
||||
>>> metric.clear()
|
||||
>>> metric.update(x, y)
|
||||
>>> dice = metric.eval()
|
||||
0.22222926
|
||||
>>> print(dice)
|
||||
0.20467791371802546
|
||||
"""
|
||||
|
||||
def __init__(self, smooth=1e-5, threshold=0.5):
|
||||
def __init__(self, smooth=1e-5):
|
||||
super(Dice, self).__init__()
|
||||
|
||||
self.smooth = validator.check_positive_float(smooth, "smooth")
|
||||
self.threshold = validator.check_value_type("threshold", threshold, [float])
|
||||
self._dice_coeff_sum = 0
|
||||
self._samples_num = 0
|
||||
self.clear()
|
||||
|
||||
def clear(self):
|
||||
"""Clears the internal evaluation result."""
|
||||
self._dim = 0
|
||||
self.intersection = 0
|
||||
self.unionset = 0
|
||||
self._dice_coeff_sum = 0
|
||||
self._samples_num = 0
|
||||
|
||||
def update(self, *inputs):
|
||||
"""
|
||||
|
@ -62,7 +62,7 @@ class Dice(Metric):
|
|||
|
||||
Args:
|
||||
inputs: Input `y_pred` and `y`. `y_pred` and `y` are Tensor, list or numpy.ndarray. `y_pred` is the
|
||||
predicted value, `y` is the true value. The shape of `y_pred` and `y` are both :math:`(N, C)`.
|
||||
predicted value, `y` is the true value. The shape of `y_pred` and `y` are both :math:`(N, ...)`.
|
||||
|
||||
Raises:
|
||||
ValueError: If the number of the inputs is not 2.
|
||||
|
@ -72,17 +72,17 @@ class Dice(Metric):
|
|||
|
||||
y_pred = self._convert_data(inputs[0])
|
||||
y = self._convert_data(inputs[1])
|
||||
self._samples_num += y.shape[0]
|
||||
|
||||
if y_pred.shape != y.shape:
|
||||
raise RuntimeError('y_pred and y should have same the dimension, but the shape of y_pred is{}, '
|
||||
'the shape of y is {}.'.format(y_pred.shape, y.shape))
|
||||
|
||||
y_pred = (y_pred > self.threshold).astype(int)
|
||||
self._dim = y.shape
|
||||
pred_flat = np.reshape(y_pred, (self._dim[0], -1))
|
||||
true_flat = np.reshape(y, (self._dim[0], -1))
|
||||
self.intersection = np.sum((pred_flat * true_flat), axis=1)
|
||||
self.unionset = np.sum(pred_flat, axis=1) + np.sum(true_flat, axis=1)
|
||||
intersection = np.dot(y_pred.flatten(), y.flatten())
|
||||
unionset = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten())
|
||||
|
||||
single_dice_coeff = 2 * float(intersection) / float(unionset + self.smooth)
|
||||
self._dice_coeff_sum += single_dice_coeff
|
||||
|
||||
def eval(self):
|
||||
r"""
|
||||
|
@ -92,11 +92,9 @@ class Dice(Metric):
|
|||
Float, the computed result.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the sample size is 0.
|
||||
RuntimeError: If the total samples num is 0.
|
||||
"""
|
||||
if self._dim[0] == 0:
|
||||
raise RuntimeError('Dice can not be calculated, because the number of samples is 0.')
|
||||
if self._samples_num == 0:
|
||||
raise RuntimeError('Total samples num must not be 0.')
|
||||
|
||||
dice = (2 * self.intersection + self.smooth) / (self.unionset + self.smooth)
|
||||
|
||||
return np.sum(dice) / self._dim[0]
|
||||
return self._dice_coeff_sum / float(self._samples_num)
|
||||
|
|
|
@ -29,12 +29,12 @@ def test_classification_dice():
|
|||
metric.update(x, y)
|
||||
dice = metric.eval()
|
||||
|
||||
assert math.isclose(dice, 0.22222926, abs_tol=0.001)
|
||||
assert math.isclose(dice, 0.20467791371802546, abs_tol=0.001)
|
||||
|
||||
|
||||
def test_dice_update1():
|
||||
x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]]))
|
||||
metric = Dice(1e-5, 0.5)
|
||||
metric = Dice(1e-5)
|
||||
metric.clear()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
|
@ -42,8 +42,8 @@ def test_dice_update1():
|
|||
|
||||
|
||||
def test_dice_runtime():
|
||||
metric = Dice(1e-5, 0.8)
|
||||
metric = Dice(1e-5)
|
||||
metric.clear()
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
with pytest.raises(RuntimeError):
|
||||
metric.eval()
|
||||
|
|
|
@ -97,8 +97,7 @@ def test_dice_loss():
|
|||
y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
|
||||
y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
|
||||
# Pass the test if no error is reported
|
||||
loss(y_pred, y).asnumpy()
|
||||
|
||||
loss(y_pred, y)
|
||||
|
||||
|
||||
def test_dice_loss_check_shape():
|
||||
|
|
Loading…
Reference in New Issue