!10939 surface distance

From: @lijiaqi0612
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-01-07 11:02:15 +08:00 committed by Gitee
commit 1276715c11
5 changed files with 425 additions and 1 deletions

View File

@ -28,9 +28,12 @@ from .fbeta import Fbeta, F1
from .dice import Dice
from .topk import TopKCategoricalAccuracy, Top1CategoricalAccuracy, Top5CategoricalAccuracy
from .loss import Loss
from .mean_surface_distance import MeanSurfaceDistance
from .root_mean_square_surface_distance import RootMeanSquareDistance
__all__ = [
"names", "get_metric_fn",
"names",
"get_metric_fn",
"Accuracy",
"MAE", "MSE",
"Metric",
@ -44,6 +47,8 @@ __all__ = [
"Top1CategoricalAccuracy",
"Top5CategoricalAccuracy",
"Loss",
"MeanSurfaceDistance",
"RootMeanSquareDistance",
]
__factory__ = {
@ -60,6 +65,8 @@ __factory__ = {
'mae': MAE,
'mse': MSE,
'loss': Loss,
'mean_surface_distance': MeanSurfaceDistance,
'root_mean_square_distance': RootMeanSquareDistance,
}

View File

@ -0,0 +1,137 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MeanSurfaceDistance."""
from scipy.ndimage import morphology
import numpy as np
from mindspore._checkparam import Validator as validator
from .metric import Metric
class MeanSurfaceDistance(Metric):
"""
This function is used to compute the Average Surface Distance from `y_pred` to `y` under the default setting.
Mean Surface Distance(MSD), the mean of the vector is taken. This tell us how much, on average, the surface varies
between the segmentation and the GT.
Args:
distance_metric (string): The parameter of calculating Hausdorff distance supports three measurement methods,
"euclidean", "chessboard" or "taxicab". Default: "euclidean".
symmetric (bool): if calculate the symmetric average surface distance between `y_pred` and `y`. In addition,
if sets ``symmetric = True``, the average symmetric surface distance between these two inputs
will be returned. Defaults: False.
Examples:
>>> x = Tensor(np.array([[3, 0, 1], [1, 3, 0], [1, 0, 2]]))
>>> y = Tensor(np.array([[0, 2, 1], [1, 2, 1], [0, 0, 1]]))
>>> metric = nn.MeanSurfaceDistance(symmetric=False, distance_metric="euclidean")
>>> metric.clear()
>>> metric.update(x, y, 0)
>>> mean_average_distance = metric.eval()
>>> print(mean_average_distance)
0.8047378541243649
"""
def __init__(self, symmetric=False, distance_metric="euclidean"):
super(MeanSurfaceDistance, self).__init__()
self.distance_metric_list = ["euclidean", "chessboard", "taxicab"]
distance_metric = validator.check_value_type("distance_metric", distance_metric, [str])
self.distance_metric = validator.check_string(distance_metric, self.distance_metric_list, "distance_metric")
self.symmetric = validator.check_value_type("symmetric", symmetric, [bool])
self.clear()
def clear(self):
"""Clears the internal evaluation result."""
self._y_pred_edges = 0
self._y_edges = 0
self._is_update = False
def _get_surface_distance(self, y_pred_edges, y_edges):
"""
Calculate the surface distances from `y_pred_edges` to `y_edges`.
Args:
y_pred_edges (np.ndarray): the edge of the predictions.
y_edges (np.ndarray): the edge of the ground truth.
"""
if not np.any(y_pred_edges):
return np.array([])
if not np.any(y_edges):
dis = np.full(y_edges.shape, np.inf)
else:
if self.distance_metric == "euclidean":
dis = morphology.distance_transform_edt(~y_edges)
elif self.distance_metric in self.distance_metric_list[-2:]:
dis = morphology.distance_transform_cdt(~y_edges, metric=self.distance_metric)
surface_distance = dis[y_pred_edges]
return surface_distance
def update(self, *inputs):
"""
Updates the internal evaluation result 'y_pred', 'y' and 'label_idx'.
Args:
inputs: Input 'y_pred', 'y' and 'label_idx'. 'y_pred' and 'y' are Tensor or numpy.ndarray. 'y_pred' is the
predicted binary image. 'y' is the actual binary image. 'label_idx', the data type of `label_idx`
is int.
Raises:
ValueError: If the number of the inputs is not 3.
"""
if len(inputs) != 3:
raise ValueError('MeanSurfaceDistance need 3 inputs (y_pred, y, label), but got {}.'.format(len(inputs)))
y_pred = self._convert_data(inputs[0])
y = self._convert_data(inputs[1])
label_idx = inputs[2]
if y_pred.size == 0 or y_pred.shape != y.shape:
raise ValueError("y_pred and y should have same shape, but got {}, {}.".format(y_pred.shape, y.shape))
if y_pred.dtype != bool:
y_pred = y_pred == label_idx
if y.dtype != bool:
y = y == label_idx
self._y_pred_edges = morphology.binary_erosion(y_pred) ^ y_pred
self._y_edges = morphology.binary_erosion(y) ^ y
self._is_update = True
def eval(self):
"""
Calculate mean surface distance.
"""
if self._is_update is False:
raise RuntimeError('Call the update method before calling eval.')
mean_surface_distance = self._get_surface_distance(self._y_pred_edges, self._y_edges)
if mean_surface_distance.shape == (0,):
return np.inf
avg_surface_distance = mean_surface_distance.mean()
if not self.symmetric:
return avg_surface_distance
contrary_mean_surface_distance = self._get_surface_distance(self._y_edges, self._y_pred_edges)
if contrary_mean_surface_distance.shape == (0,):
return np.inf
contrary_avg_surface_distance = contrary_mean_surface_distance.mean()
return np.mean((avg_surface_distance, contrary_avg_surface_distance))

View File

@ -0,0 +1,140 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""RootMeanSquareSurfaceDistance."""
from scipy.ndimage import morphology
import numpy as np
from mindspore._checkparam import Validator as validator
from .metric import Metric
class RootMeanSquareDistance(Metric):
"""
This function is used to compute the Residual Mean Square Distance from `y_pred` to `y` under the default
setting. Residual Mean Square Distance(RMS), the mean is taken from each of the points in the vector, these
residuals are squared (to remove negative signs), summed, weighted by the mean and then the square-root is taken.
Measured in mm.
Args:
distance_metric (string): The parameter of calculating Hausdorff distance supports three measurement methods,
"euclidean", "chessboard" or "taxicab". Default: "euclidean".
symmetric (bool): if calculate the symmetric average surface distance between `y_pred` and `y`. In addition,
if sets ``symmetric = True``, the average symmetric surface distance between these two inputs
will be returned. Defaults: False.
Examples:
>>> x = Tensor(np.array([[3, 0, 1], [1, 3, 0], [1, 0, 2]]))
>>> y = Tensor(np.array([[0, 2, 1], [1, 2, 1], [0, 0, 1]]))
>>> metric = nn.RootMeanSquareDistance(symmetric=False, distance_metric="euclidean")
>>> metric.clear()
>>> metric.update(x, y, 0)
>>> root_mean_square_distance = metric.eval()
>>> print(root_mean_square_distance)
1.0000000000000002
"""
def __init__(self, symmetric=False, distance_metric="euclidean"):
super(RootMeanSquareDistance, self).__init__()
self.distance_metric_list = ["euclidean", "chessboard", "taxicab"]
distance_metric = validator.check_value_type("distance_metric", distance_metric, [str])
self.distance_metric = validator.check_string(distance_metric, self.distance_metric_list, "distance_metric")
self.symmetric = validator.check_value_type("symmetric", symmetric, [bool])
self.clear()
def clear(self):
"""Clears the internal evaluation result."""
self._y_pred_edges = 0
self._y_edges = 0
self._is_update = False
def _get_surface_distance(self, y_pred_edges, y_edges):
"""
Calculate the surface distances from `y_pred_edges` to `y_edges`.
Args:
y_pred_edges (np.ndarray): the edge of the predictions.
y_edges (np.ndarray): the edge of the ground truth.
"""
if not np.any(y_pred_edges):
return np.array([])
if not np.any(y_edges):
dis = np.full(y_edges.shape, np.inf)
else:
if self.distance_metric == "euclidean":
dis = morphology.distance_transform_edt(~y_edges)
elif self.distance_metric in self.distance_metric_list[-2:]:
dis = morphology.distance_transform_cdt(~y_edges, metric=self.distance_metric)
surface_distance = dis[y_pred_edges]
return surface_distance
def update(self, *inputs):
"""
Updates the internal evaluation result 'y_pred', 'y' and 'label_idx'.
Args:
inputs: Input 'y_pred', 'y' and 'label_idx'. 'y_pred' and 'y' are Tensor or numpy.ndarray. 'y_pred' is the
predicted binary image. 'y' is the actual binary image. 'label_idx', the data type of `label_idx`
is int.
Raises:
ValueError: If the number of the inputs is not 3.
"""
if len(inputs) != 3:
raise ValueError('MeanSurfaceDistance need 3 inputs (y_pred, y, label), but got {}.'.format(len(inputs)))
y_pred = self._convert_data(inputs[0])
y = self._convert_data(inputs[1])
label_idx = inputs[2]
if y_pred.size == 0 or y_pred.shape != y.shape:
raise ValueError("y_pred and y should have same shape, but got {}, {}.".format(y_pred.shape, y.shape))
if y_pred.dtype != bool:
y_pred = y_pred == label_idx
if y.dtype != bool:
y = y == label_idx
self._y_pred_edges = morphology.binary_erosion(y_pred) ^ y_pred
self._y_edges = morphology.binary_erosion(y) ^ y
self._is_update = True
def eval(self):
"""
Calculate residual mean square surface distance.
"""
if self._is_update is False:
raise RuntimeError('Call the update method before calling eval.')
residual_mean_square_distance = self._get_surface_distance(self._y_pred_edges, self._y_edges)
if residual_mean_square_distance.shape == (0,):
return np.inf
rms_surface_distance = (residual_mean_square_distance**2).mean()
if not self.symmetric:
return rms_surface_distance
contrary_residual_mean_square_distance = self._get_surface_distance(self._y_edges, self._y_pred_edges)
if contrary_residual_mean_square_distance.shape == (0,):
return np.inf
contrary_rms_surface_distance = (contrary_residual_mean_square_distance**2).mean()
rms_distance = np.sqrt(np.mean((rms_surface_distance, contrary_rms_surface_distance)))
return rms_distance

View File

@ -0,0 +1,70 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# """test_mean_surface_distance"""
import math
import numpy as np
import pytest
from mindspore import Tensor
from mindspore.nn.metrics import get_metric_fn, MeanSurfaceDistance
def test_mean_surface_distance():
"""test_mean_surface_distance"""
x = Tensor(np.array([[3, 0, 1], [1, 3, 0], [1, 0, 2]]))
y = Tensor(np.array([[0, 2, 1], [1, 2, 1], [0, 0, 1]]))
metric = get_metric_fn('mean_surface_distance')
metric.clear()
metric.update(x, y, 0)
distance = metric.eval()
assert math.isclose(distance, 0.8047378541243649, abs_tol=0.001)
def test_mean_surface_distance_update1():
x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]]))
metric = MeanSurfaceDistance()
metric.clear()
with pytest.raises(ValueError):
metric.update(x)
def test_mean_surface_distance_update2():
x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]]))
y = Tensor(np.array([1, 0]))
metric = MeanSurfaceDistance()
metric.clear()
with pytest.raises(ValueError):
metric.update(x, y)
def test_mean_surface_distance_init():
with pytest.raises(ValueError):
MeanSurfaceDistance(symmetric=False, distance_metric="eucli")
def test_mean_surface_distance_init2():
with pytest.raises(TypeError):
MeanSurfaceDistance(symmetric=1)
def test_mean_surface_distance_runtime():
metric = MeanSurfaceDistance()
metric.clear()
with pytest.raises(RuntimeError):
metric.eval()

View File

@ -0,0 +1,70 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# """test_mean_surface_distance"""
import math
import numpy as np
import pytest
from mindspore import Tensor
from mindspore.nn.metrics import get_metric_fn, RootMeanSquareDistance
def test_root_mean_square_distance():
"""test_root_mean_square_distance"""
x = Tensor(np.array([[3, 0, 1], [1, 3, 0], [1, 0, 2]]))
y = Tensor(np.array([[0, 2, 1], [1, 2, 1], [0, 0, 1]]))
metric = get_metric_fn('root_mean_square_distance')
metric.clear()
metric.update(x, y, 0)
distance = metric.eval()
assert math.isclose(distance, 1.0000000000000002, abs_tol=0.001)
def test_root_mean_square_distance_update1():
x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]]))
metric = RootMeanSquareDistance()
metric.clear()
with pytest.raises(ValueError):
metric.update(x)
def test_root_mean_square_distance_update2():
x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]]))
y = Tensor(np.array([1, 0]))
metric = RootMeanSquareDistance()
metric.clear()
with pytest.raises(ValueError):
metric.update(x, y)
def test_root_mean_square_distance_init():
with pytest.raises(ValueError):
RootMeanSquareDistance(symmetric=False, distance_metric="eucli")
def test_root_mean_square_distance_init2():
with pytest.raises(TypeError):
RootMeanSquareDistance(symmetric=1)
def test_root_mean_square_distance_runtime():
metric = RootMeanSquareDistance()
metric.clear()
with pytest.raises(RuntimeError):
metric.eval()