diff --git a/mindspore/nn/metrics/__init__.py b/mindspore/nn/metrics/__init__.py
index 1aac2138900..411e8c497e6 100755
--- a/mindspore/nn/metrics/__init__.py
+++ b/mindspore/nn/metrics/__init__.py
@@ -25,6 +25,7 @@ from .metric import Metric
 from .precision import Precision
 from .recall import Recall
 from .fbeta import Fbeta, F1
+from .dice import Dice
 from .topk import TopKCategoricalAccuracy, Top1CategoricalAccuracy, Top5CategoricalAccuracy
 from .loss import Loss
@@ -38,6 +39,7 @@ __all__ = [
     "Recall",
     "Fbeta",
     "F1",
+    "Dice",
     "TopKCategoricalAccuracy",
     "Top1CategoricalAccuracy",
     "Top5CategoricalAccuracy",
@@ -50,6 +52,7 @@ __factory__ = {
     'precision': Precision,
     'recall': Recall,
     'F1': F1,
+    'dice': Dice,
     'topk': TopKCategoricalAccuracy,
     'hausdorff_distance': HausdorffDistance,
     'top_1_accuracy': Top1CategoricalAccuracy,
diff --git a/mindspore/nn/metrics/dice.py b/mindspore/nn/metrics/dice.py
new file mode 100644
index 00000000000..a4110c1e7ac
--- /dev/null
+++ b/mindspore/nn/metrics/dice.py
@@ -0,0 +1,107 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Dice"""
+import numpy as np
+from mindspore._checkparam import Validator as validator
+from .metric import Metric
+
+
+class Dice(Metric):
+    r"""
+    The Dice coefficient is a set similarity metric used to calculate the similarity between two samples. Its value
+    is 1 when the segmentation result is perfect and 0 when the prediction and the ground truth do not overlap at
+    all. The coefficient equals twice the area of the overlap between the two sets divided by the sum of their areas.
+    The formula is shown as follows:
+
+    .. math::
+        \text{dice} = \frac{2 \times |\text{pred} \cap \text{true}|}{|\text{pred}| + |\text{true}|}
+
+    Args:
+        smooth (float): A term added to the numerator and denominator to improve numerical stability. Should be
+            greater than 0. Default: 1e-5.
+        threshold (float): The threshold used to binarize the predictions before comparison. Default: 0.5.
+
+    Examples:
+        >>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
+        >>> y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]))
+        >>> metric = Dice(smooth=1e-5, threshold=0.5)
+        >>> metric.clear()
+        >>> metric.update(x, y)
+        >>> dice = metric.eval()
+        >>> print(dice)
+        0.22222926
+    """
+
+    def __init__(self, smooth=1e-5, threshold=0.5):
+        super(Dice, self).__init__()
+
+        self.smooth = validator.check_positive_float(smooth, "smooth")
+        self.threshold = validator.check_value_type("threshold", threshold, [float])
+        self.clear()
+
+    def clear(self):
+        """Clears the internal evaluation result."""
+        self._dim = (0,)
+        self.intersection = 0
+        self.unionset = 0
+
+    def update(self, *inputs):
+        """
+        Updates the internal evaluation result :math:`y_{pred}` and :math:`y`.
+
+        Args:
+            inputs: Input `y_pred` and `y`. `y_pred` and `y` are Tensor, list or numpy.ndarray. `y_pred` is the
+                predicted value and `y` is the true value. The shapes of `y_pred` and `y` are both :math:`(N, C)`.
+
+        Raises:
+            ValueError: If the number of inputs is not 2.
+            RuntimeError: If the shapes of `y_pred` and `y` do not match.
+        """
+        if len(inputs) != 2:
+            raise ValueError('Dice needs 2 inputs (y_pred, y), but got {}.'.format(len(inputs)))
+
+        y_pred = self._convert_data(inputs[0])
+        y = self._convert_data(inputs[1])
+
+        if y_pred.shape != y.shape:
+            raise RuntimeError('y_pred and y should have the same shape, but the shape of y_pred is {} and '
+                               'the shape of y is {}.'.format(y_pred.shape, y.shape))
+
+        # Binarize the predictions with the threshold and flatten each sample.
+        y_pred = (y_pred > self.threshold).astype(int)
+        self._dim = y.shape
+        pred_flat = np.reshape(y_pred, (self._dim[0], -1))
+        true_flat = np.reshape(y, (self._dim[0], -1))
+        # Per-sample overlap and the sum of the per-sample element counts.
+        self.intersection = np.sum((pred_flat * true_flat), axis=1)
+        self.unionset = np.sum(pred_flat, axis=1) + np.sum(true_flat, axis=1)
+
+    def eval(self):
+        r"""
+        Computes the Dice coefficient.
+
+        Returns:
+            Float, the computed result.
+
+        Raises:
+            RuntimeError: If the sample size is 0.
+        """
+        if self._dim[0] == 0:
+            raise RuntimeError('Dice can not be calculated, because the number of samples is 0.')
+
+        # Per-sample Dice, averaged over the batch.
+        dice = (2 * self.intersection + self.smooth) / (self.unionset + self.smooth)
+
+        return np.sum(dice) / self._dim[0]
diff --git a/tests/ut/python/metrics/test_dice.py b/tests/ut/python/metrics/test_dice.py
new file mode 100644
index 00000000000..4a9fbb7ba53
--- /dev/null
+++ b/tests/ut/python/metrics/test_dice.py
@@ -0,0 +1,49 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""test_dice"""
+import math
+import numpy as np
+import pytest
+from mindspore import Tensor
+from mindspore.nn.metrics import get_metric_fn, Dice
+
+
+def test_classification_dice():
+    """test_dice"""
+    x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
+    y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]))
+    metric = get_metric_fn('dice')
+    metric.clear()
+    metric.update(x, y)
+    dice = metric.eval()
+
+    assert math.isclose(dice, 0.22222926, abs_tol=0.001)
+
+
+def test_dice_update1():
+    x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]]))
+    metric = Dice(1e-5, 0.5)
+    metric.clear()
+
+    with pytest.raises(ValueError):
+        metric.update(x)
+
+
+def test_dice_runtime():
+    metric = Dice(1e-5, 0.8)
+    metric.clear()
+
+    with pytest.raises(RuntimeError):
+        metric.eval()
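
A quick, standalone sanity check of the 0.22222926 value used in the Dice docstring example and in test_classification_dice: the sketch below mirrors the arithmetic in update() and eval() for those inputs with plain NumPy. It is illustrative only and not part of the patch; the intermediate values in the comments were worked out by hand.

    import numpy as np

    # Inputs from the Dice docstring example (smooth=1e-5, threshold=0.5).
    x = np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]])
    y = np.array([[0, 1], [1, 0], [0, 1]])
    smooth, threshold = 1e-5, 0.5

    # update(): binarize the predictions and flatten each sample.
    y_pred = (x > threshold).astype(int)                    # [[0, 0], [0, 0], [1, 1]]
    pred_flat = y_pred.reshape(y_pred.shape[0], -1)
    true_flat = y.reshape(y.shape[0], -1)
    intersection = np.sum(pred_flat * true_flat, axis=1)    # [0, 0, 1]
    unionset = np.sum(pred_flat, axis=1) + np.sum(true_flat, axis=1)  # [1, 1, 3]

    # eval(): smoothed per-sample Dice, averaged over the batch.
    dice = (2 * intersection + smooth) / (unionset + smooth)
    print(np.sum(dice) / y.shape[0])                        # ~0.22222926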
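
A second sketch shows the effect of the threshold argument, which only binarizes y_pred; y is used as-is. With a stricter threshold the same inputs no longer overlap with the ground truth, so the result is driven entirely by the smoothing term and collapses to roughly zero. Again a hypothetical usage example, not part of the patch; the printed value is approximate.

    import numpy as np
    from mindspore import Tensor
    from mindspore.nn.metrics import Dice

    x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
    y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]))

    # threshold=0.8 keeps only the 0.9 prediction, which does not overlap
    # with the ground truth, so every per-sample Dice comes from `smooth` alone.
    metric = Dice(smooth=1e-5, threshold=0.8)
    metric.clear()
    metric.update(x, y)
    print(metric.eval())   # close to 0 (about 8e-06)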