!11813 modify dice and diceloss

From: @lijiaqi0612
Reviewed-by: @kingxian,@zh_qh
Signed-off-by: @kingxian
This commit is contained in:
mindspore-ci-bot 2021-01-30 11:08:17 +08:00 committed by Gitee
commit 84afbb510f
4 changed files with 39 additions and 45 deletions

View File

@ -310,11 +310,10 @@ class DiceLoss(_Loss):
Args:
smooth (float): A term added to the denominator to improve numerical stability. Should be greater than 0.
Default: 1e-5.
threshold (float): A threshold, which is used to compare with the input tensor. Default: 0.5.
Inputs:
- **y_pred** (Tensor) - Tensor of shape (N, C).
- **y** (Tensor) - Tensor of shape (N, C).
- **y_pred** (Tensor) - Tensor of shape (N, ...).
- **y** (Tensor) - Tensor of shape (N, ...).
Outputs:
Tensor, a tensor of shape with the per-example sampled Dice losses.
@ -323,32 +322,30 @@ class DiceLoss(_Loss):
``Ascend``
Examples:
>>> loss = nn.Diceloss(smooth=1e-5, threshold=0.5)
>>> loss = nn.DiceLoss(smooth=1e-5)
>>> y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
>>> y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
>>> output = loss(y_pred, y)
>>> print(output)
[0.77777076]
[0.7953220862819745]
Raises:
ValueError: If the dimensions are different.
TypeError: If the type of inputs are not Tensor.
"""
def __init__(self, smooth=1e-5, threshold=0.5):
def __init__(self, smooth=1e-5):
super(DiceLoss, self).__init__()
self.smooth = validator.check_positive_float(smooth, "smooth")
self.threshold = validator.check_value_type("threshold", threshold, [float])
self.reshape = P.Reshape()
def construct(self, logits, label):
_check_shape(logits.shape, label.shape)
logits = self.cast((logits > self.threshold), mstype.float32)
label = self.cast(label, mstype.float32)
dim = label.shape
pred_flat = self.reshape(logits, (dim[0], -1))
true_flat = self.reshape(label, (dim[0], -1))
intersection = self.reduce_sum(self.mul(logits.view(-1), label.view(-1)))
unionset = self.reduce_sum(self.mul(logits.view(-1), logits.view(-1))) + \
self.reduce_sum(self.mul(label.view(-1), label.view(-1)))
intersection = self.reduce_sum((pred_flat * true_flat), 1)
unionset = self.reduce_sum(pred_flat, 1) + self.reduce_sum(true_flat, 1)
dice = (2 * intersection + self.smooth) / (unionset + self.smooth)
dice_loss = 1 - self.reduce_sum(dice) / dim[0]
single_dice_coeff = (2 * intersection) / (unionset + self.smooth)
dice_loss = 1 - single_dice_coeff / label.shape[0]
return dice_loss

View File

@ -26,35 +26,35 @@ class Dice(Metric):
The function is shown as follows:
.. math::
dice = \frac{2 * (pred \bigcap true)}{pred \bigcup true}
dice = \frac{2 * (pred \bigcap true)}{pred \bigcup true}
Args:
smooth (float): A term added to the denominator to improve numerical stability. Should be greater than 0.
Default: 1e-5.
threshold (float): A threshold, which is used to compare with the input tensor. Default: 0.5.
Examples:
>>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
>>> y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]))
>>> metric = Dice(smooth=1e-5, threshold=0.5)
>>> metric = Dice(smooth=1e-5)
>>> metric.clear()
>>> metric.update(x, y)
>>> dice = metric.eval()
0.22222926
>>> print(dice)
0.20467791371802546
"""
def __init__(self, smooth=1e-5, threshold=0.5):
def __init__(self, smooth=1e-5):
super(Dice, self).__init__()
self.smooth = validator.check_positive_float(smooth, "smooth")
self.threshold = validator.check_value_type("threshold", threshold, [float])
self._dice_coeff_sum = 0
self._samples_num = 0
self.clear()
def clear(self):
"""Clears the internal evaluation result."""
self._dim = 0
self.intersection = 0
self.unionset = 0
self._dice_coeff_sum = 0
self._samples_num = 0
def update(self, *inputs):
"""
@ -62,7 +62,7 @@ class Dice(Metric):
Args:
inputs: Input `y_pred` and `y`. `y_pred` and `y` are Tensor, list or numpy.ndarray. `y_pred` is the
predicted value, `y` is the true value. The shape of `y_pred` and `y` are both :math:`(N, C)`.
predicted value, `y` is the true value. The shape of `y_pred` and `y` are both :math:`(N, ...)`.
Raises:
ValueError: If the number of the inputs is not 2.
@ -72,17 +72,17 @@ class Dice(Metric):
y_pred = self._convert_data(inputs[0])
y = self._convert_data(inputs[1])
self._samples_num += y.shape[0]
if y_pred.shape != y.shape:
raise RuntimeError('y_pred and y should have same the dimension, but the shape of y_pred is{}, '
'the shape of y is {}.'.format(y_pred.shape, y.shape))
y_pred = (y_pred > self.threshold).astype(int)
self._dim = y.shape
pred_flat = np.reshape(y_pred, (self._dim[0], -1))
true_flat = np.reshape(y, (self._dim[0], -1))
self.intersection = np.sum((pred_flat * true_flat), axis=1)
self.unionset = np.sum(pred_flat, axis=1) + np.sum(true_flat, axis=1)
intersection = np.dot(y_pred.flatten(), y.flatten())
unionset = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten())
single_dice_coeff = 2 * float(intersection) / float(unionset + self.smooth)
self._dice_coeff_sum += single_dice_coeff
def eval(self):
r"""
@ -92,11 +92,9 @@ class Dice(Metric):
Float, the computed result.
Raises:
RuntimeError: If the sample size is 0.
RuntimeError: If the total samples num is 0.
"""
if self._dim[0] == 0:
raise RuntimeError('Dice can not be calculated, because the number of samples is 0.')
if self._samples_num == 0:
raise RuntimeError('Total samples num must not be 0.')
dice = (2 * self.intersection + self.smooth) / (self.unionset + self.smooth)
return np.sum(dice) / self._dim[0]
return self._dice_coeff_sum / float(self._samples_num)

View File

@ -29,12 +29,12 @@ def test_classification_dice():
metric.update(x, y)
dice = metric.eval()
assert math.isclose(dice, 0.22222926, abs_tol=0.001)
assert math.isclose(dice, 0.20467791371802546, abs_tol=0.001)
def test_dice_update1():
x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]]))
metric = Dice(1e-5, 0.5)
metric = Dice(1e-5)
metric.clear()
with pytest.raises(ValueError):
@ -42,8 +42,8 @@ def test_dice_update1():
def test_dice_runtime():
metric = Dice(1e-5, 0.8)
metric = Dice(1e-5)
metric.clear()
with pytest.raises(TypeError):
with pytest.raises(RuntimeError):
metric.eval()

View File

@ -97,8 +97,7 @@ def test_dice_loss():
y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
# Pass the test if no error is reported
loss(y_pred, y).asnumpy()
loss(y_pred, y)
def test_dice_loss_check_shape():