!9083 add dice and hausdroff

From: @lijiaqi0612
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-01-07 10:31:42 +08:00 committed by Gitee
commit 09848d7ef3
3 changed files with 154 additions and 0 deletions

View File

@ -25,6 +25,7 @@ from .metric import Metric
from .precision import Precision
from .recall import Recall
from .fbeta import Fbeta, F1
from .dice import Dice
from .topk import TopKCategoricalAccuracy, Top1CategoricalAccuracy, Top5CategoricalAccuracy
from .loss import Loss
@ -38,6 +39,7 @@ __all__ = [
"Recall",
"Fbeta",
"F1",
"Dice",
"TopKCategoricalAccuracy",
"Top1CategoricalAccuracy",
"Top5CategoricalAccuracy",
@ -50,6 +52,7 @@ __factory__ = {
'precision': Precision,
'recall': Recall,
'F1': F1,
'dice': Dice,
'topk': TopKCategoricalAccuracy,
'hausdorff_distance': HausdorffDistance,
'top_1_accuracy': Top1CategoricalAccuracy,

View File

@ -0,0 +1,102 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dice"""
import numpy as np
from mindspore._checkparam import Validator as validator
from .metric import Metric
class Dice(Metric):
r"""
The Dice coefficient is a set similarity metric. It is used to calculate the similarity between two samples. The
value of the Dice coefficient is 1 when the segmentation result is the best and 0 when the segmentation result
is the worst. The Dice coefficient indicates the ratio of the area between two objects to the total area.
The function is shown as follows:
.. math::
\text{dice} = \frac{2 * (\text{pred} \bigcap \text{true})}{\text{pred} \bigcup \text{true}}
Args:
smooth (float): A term added to the denominator to improve numerical stability. Should be greater than 0.
Default: 1e-5.
threshold (float): A threshold, which is used to compare with the input tensor. Default: 0.5.
Examples:
>>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
>>> y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]))
>>> metric = Dice(smooth=1e-5, threshold=0.5)
>>> metric.clear()
>>> metric.update(x, y)
>>> dice = metric.eval()
0.22222926
"""
def __init__(self, smooth=1e-5, threshold=0.5):
super(Dice, self).__init__()
self.smooth = validator.check_positive_float(smooth, "smooth")
self.threshold = validator.check_value_type("threshold", threshold, [float])
self.clear()
def clear(self):
"""Clears the internal evaluation result."""
self._dim = 0
self.intersection = 0
self.unionset = 0
def update(self, *inputs):
"""
Updates the internal evaluation result :math:`y_{pred}` and :math:`y`.
Args:
inputs: Input `y_pred` and `y`. `y_pred` and `y` are Tensor, list or numpy.ndarray. `y_pred` is the
predicted value, `y` is the true value. The shape of `y_pred` and `y` are both :math:`(N, C)`.
Raises:
ValueError: If the number of the inputs is not 2.
"""
if len(inputs) != 2:
raise ValueError('Dice need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
y_pred = self._convert_data(inputs[0])
y = self._convert_data(inputs[1])
if y_pred.shape != y.shape:
raise RuntimeError('y_pred and y should have same the dimension, but the shape of y_pred is{}, '
'the shape of y is {}.'.format(y_pred.shape, y.shape))
y_pred = (y_pred > self.threshold).astype(int)
self._dim = y.shape
pred_flat = np.reshape(y_pred, (self._dim[0], -1))
true_flat = np.reshape(y, (self._dim[0], -1))
self.intersection = np.sum((pred_flat * true_flat), axis=1)
self.unionset = np.sum(pred_flat, axis=1) + np.sum(true_flat, axis=1)
def eval(self):
r"""
Computes the Dice.
Returns:
Float, the computed result.
Raises:
RuntimeError: If the sample size is 0.
"""
if self._dim[0] == 0:
raise RuntimeError('Dice can not be calculated, because the number of samples is 0.')
dice = (2 * self.intersection + self.smooth) / (self.unionset + self.smooth)
return np.sum(dice) / self._dim[0]

View File

@ -0,0 +1,49 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# """test_dice"""
import math
import numpy as np
import pytest
from mindspore import Tensor
from mindspore.nn.metrics import get_metric_fn, Dice
def test_classification_dice():
"""test_dice"""
x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]))
metric = get_metric_fn('dice')
metric.clear()
metric.update(x, y)
dice = metric.eval()
assert math.isclose(dice, 0.22222926, abs_tol=0.001)
def test_dice_update1():
x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]]))
metric = Dice(1e-5, 0.5)
metric.clear()
with pytest.raises(ValueError):
metric.update(x)
def test_dice_runtime():
metric = Dice(1e-5, 0.8)
metric.clear()
with pytest.raises(TypeError):
metric.eval()