forked from mindspore-Ecosystem/mindspore
surface distance
This commit is contained in:
parent
5602994d48
commit
ed7cf87d0b
|
@ -26,9 +26,12 @@ from .recall import Recall
|
|||
from .fbeta import Fbeta, F1
|
||||
from .topk import TopKCategoricalAccuracy, Top1CategoricalAccuracy, Top5CategoricalAccuracy
|
||||
from .loss import Loss
|
||||
from .mean_surface_distance import MeanSurfaceDistance
|
||||
from .root_mean_square_surface_distance import RootMeanSquareDistance
|
||||
|
||||
__all__ = [
|
||||
"names", "get_metric_fn",
|
||||
"names",
|
||||
"get_metric_fn",
|
||||
"Accuracy",
|
||||
"MAE", "MSE",
|
||||
"Metric",
|
||||
|
@ -40,6 +43,8 @@ __all__ = [
|
|||
"Top1CategoricalAccuracy",
|
||||
"Top5CategoricalAccuracy",
|
||||
"Loss",
|
||||
"MeanSurfaceDistance",
|
||||
"RootMeanSquareDistance",
|
||||
]
|
||||
|
||||
__factory__ = {
|
||||
|
@ -54,6 +59,8 @@ __factory__ = {
|
|||
'mae': MAE,
|
||||
'mse': MSE,
|
||||
'loss': Loss,
|
||||
'mean_surface_distance': MeanSurfaceDistance,
|
||||
'root_mean_square_distance': RootMeanSquareDistance,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,137 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""MeanSurfaceDistance."""
|
||||
from scipy.ndimage import morphology
|
||||
import numpy as np
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from .metric import Metric
|
||||
|
||||
|
||||
class MeanSurfaceDistance(Metric):
|
||||
"""
|
||||
This function is used to compute the Average Surface Distance from `y_pred` to `y` under the default setting.
|
||||
Mean Surface Distance(MSD), the mean of the vector is taken. This tell us how much, on average, the surface varies
|
||||
between the segmentation and the GT.
|
||||
|
||||
Args:
|
||||
distance_metric (string): The parameter of calculating Hausdorff distance supports three measurement methods,
|
||||
"euclidean", "chessboard" or "taxicab". Default: "euclidean".
|
||||
symmetric (bool): if calculate the symmetric average surface distance between `y_pred` and `y`. In addition,
|
||||
if sets ``symmetric = True``, the average symmetric surface distance between these two inputs
|
||||
will be returned. Defaults: False.
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.array([[3, 0, 1], [1, 3, 0], [1, 0, 2]]))
|
||||
>>> y = Tensor(np.array([[0, 2, 1], [1, 2, 1], [0, 0, 1]]))
|
||||
>>> metric = nn.MeanSurfaceDistance(symmetric=False, distance_metric="euclidean")
|
||||
>>> metric.clear()
|
||||
>>> metric.update(x, y, 0)
|
||||
>>> mean_average_distance = metric.eval()
|
||||
>>> print(mean_average_distance)
|
||||
0.8047378541243649
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, symmetric=False, distance_metric="euclidean"):
|
||||
super(MeanSurfaceDistance, self).__init__()
|
||||
self.distance_metric_list = ["euclidean", "chessboard", "taxicab"]
|
||||
distance_metric = validator.check_value_type("distance_metric", distance_metric, [str])
|
||||
self.distance_metric = validator.check_string(distance_metric, self.distance_metric_list, "distance_metric")
|
||||
self.symmetric = validator.check_value_type("symmetric", symmetric, [bool])
|
||||
self.clear()
|
||||
|
||||
def clear(self):
|
||||
"""Clears the internal evaluation result."""
|
||||
self._y_pred_edges = 0
|
||||
self._y_edges = 0
|
||||
self._is_update = False
|
||||
|
||||
def _get_surface_distance(self, y_pred_edges, y_edges):
|
||||
"""
|
||||
Calculate the surface distances from `y_pred_edges` to `y_edges`.
|
||||
|
||||
Args:
|
||||
y_pred_edges (np.ndarray): the edge of the predictions.
|
||||
y_edges (np.ndarray): the edge of the ground truth.
|
||||
"""
|
||||
|
||||
if not np.any(y_pred_edges):
|
||||
return np.array([])
|
||||
|
||||
if not np.any(y_edges):
|
||||
dis = np.full(y_edges.shape, np.inf)
|
||||
else:
|
||||
if self.distance_metric == "euclidean":
|
||||
dis = morphology.distance_transform_edt(~y_edges)
|
||||
elif self.distance_metric in self.distance_metric_list[-2:]:
|
||||
dis = morphology.distance_transform_cdt(~y_edges, metric=self.distance_metric)
|
||||
|
||||
surface_distance = dis[y_pred_edges]
|
||||
|
||||
return surface_distance
|
||||
|
||||
def update(self, *inputs):
|
||||
"""
|
||||
Updates the internal evaluation result 'y_pred', 'y' and 'label_idx'.
|
||||
|
||||
Args:
|
||||
inputs: Input 'y_pred', 'y' and 'label_idx'. 'y_pred' and 'y' are Tensor or numpy.ndarray. 'y_pred' is the
|
||||
predicted binary image. 'y' is the actual binary image. 'label_idx', the data type of `label_idx`
|
||||
is int.
|
||||
|
||||
Raises:
|
||||
ValueError: If the number of the inputs is not 3.
|
||||
"""
|
||||
if len(inputs) != 3:
|
||||
raise ValueError('MeanSurfaceDistance need 3 inputs (y_pred, y, label), but got {}.'.format(len(inputs)))
|
||||
y_pred = self._convert_data(inputs[0])
|
||||
y = self._convert_data(inputs[1])
|
||||
label_idx = inputs[2]
|
||||
|
||||
if y_pred.size == 0 or y_pred.shape != y.shape:
|
||||
raise ValueError("y_pred and y should have same shape, but got {}, {}.".format(y_pred.shape, y.shape))
|
||||
|
||||
if y_pred.dtype != bool:
|
||||
y_pred = y_pred == label_idx
|
||||
if y.dtype != bool:
|
||||
y = y == label_idx
|
||||
|
||||
self._y_pred_edges = morphology.binary_erosion(y_pred) ^ y_pred
|
||||
self._y_edges = morphology.binary_erosion(y) ^ y
|
||||
self._is_update = True
|
||||
|
||||
def eval(self):
|
||||
"""
|
||||
Calculate mean surface distance.
|
||||
"""
|
||||
if self._is_update is False:
|
||||
raise RuntimeError('Call the update method before calling eval.')
|
||||
|
||||
mean_surface_distance = self._get_surface_distance(self._y_pred_edges, self._y_edges)
|
||||
|
||||
if mean_surface_distance.shape == (0,):
|
||||
return np.inf
|
||||
|
||||
avg_surface_distance = mean_surface_distance.mean()
|
||||
|
||||
if not self.symmetric:
|
||||
return avg_surface_distance
|
||||
|
||||
contrary_mean_surface_distance = self._get_surface_distance(self._y_edges, self._y_pred_edges)
|
||||
if contrary_mean_surface_distance.shape == (0,):
|
||||
return np.inf
|
||||
|
||||
contrary_avg_surface_distance = contrary_mean_surface_distance.mean()
|
||||
return np.mean((avg_surface_distance, contrary_avg_surface_distance))
|
|
@ -0,0 +1,140 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""RootMeanSquareSurfaceDistance."""
|
||||
from scipy.ndimage import morphology
|
||||
import numpy as np
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from .metric import Metric
|
||||
|
||||
|
||||
class RootMeanSquareDistance(Metric):
|
||||
"""
|
||||
This function is used to compute the Residual Mean Square Distance from `y_pred` to `y` under the default
|
||||
setting. Residual Mean Square Distance(RMS), the mean is taken from each of the points in the vector, these
|
||||
residuals are squared (to remove negative signs), summed, weighted by the mean and then the square-root is taken.
|
||||
Measured in mm.
|
||||
|
||||
Args:
|
||||
distance_metric (string): The parameter of calculating Hausdorff distance supports three measurement methods,
|
||||
"euclidean", "chessboard" or "taxicab". Default: "euclidean".
|
||||
symmetric (bool): if calculate the symmetric average surface distance between `y_pred` and `y`. In addition,
|
||||
if sets ``symmetric = True``, the average symmetric surface distance between these two inputs
|
||||
will be returned. Defaults: False.
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.array([[3, 0, 1], [1, 3, 0], [1, 0, 2]]))
|
||||
>>> y = Tensor(np.array([[0, 2, 1], [1, 2, 1], [0, 0, 1]]))
|
||||
>>> metric = nn.RootMeanSquareDistance(symmetric=False, distance_metric="euclidean")
|
||||
>>> metric.clear()
|
||||
>>> metric.update(x, y, 0)
|
||||
>>> root_mean_square_distance = metric.eval()
|
||||
>>> print(root_mean_square_distance)
|
||||
1.0000000000000002
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, symmetric=False, distance_metric="euclidean"):
|
||||
super(RootMeanSquareDistance, self).__init__()
|
||||
self.distance_metric_list = ["euclidean", "chessboard", "taxicab"]
|
||||
distance_metric = validator.check_value_type("distance_metric", distance_metric, [str])
|
||||
self.distance_metric = validator.check_string(distance_metric, self.distance_metric_list, "distance_metric")
|
||||
self.symmetric = validator.check_value_type("symmetric", symmetric, [bool])
|
||||
self.clear()
|
||||
|
||||
def clear(self):
|
||||
"""Clears the internal evaluation result."""
|
||||
self._y_pred_edges = 0
|
||||
self._y_edges = 0
|
||||
self._is_update = False
|
||||
|
||||
def _get_surface_distance(self, y_pred_edges, y_edges):
|
||||
"""
|
||||
Calculate the surface distances from `y_pred_edges` to `y_edges`.
|
||||
|
||||
Args:
|
||||
y_pred_edges (np.ndarray): the edge of the predictions.
|
||||
y_edges (np.ndarray): the edge of the ground truth.
|
||||
"""
|
||||
|
||||
if not np.any(y_pred_edges):
|
||||
return np.array([])
|
||||
|
||||
if not np.any(y_edges):
|
||||
dis = np.full(y_edges.shape, np.inf)
|
||||
else:
|
||||
if self.distance_metric == "euclidean":
|
||||
dis = morphology.distance_transform_edt(~y_edges)
|
||||
elif self.distance_metric in self.distance_metric_list[-2:]:
|
||||
dis = morphology.distance_transform_cdt(~y_edges, metric=self.distance_metric)
|
||||
|
||||
surface_distance = dis[y_pred_edges]
|
||||
|
||||
return surface_distance
|
||||
|
||||
def update(self, *inputs):
|
||||
"""
|
||||
Updates the internal evaluation result 'y_pred', 'y' and 'label_idx'.
|
||||
|
||||
Args:
|
||||
inputs: Input 'y_pred', 'y' and 'label_idx'. 'y_pred' and 'y' are Tensor or numpy.ndarray. 'y_pred' is the
|
||||
predicted binary image. 'y' is the actual binary image. 'label_idx', the data type of `label_idx`
|
||||
is int.
|
||||
|
||||
Raises:
|
||||
ValueError: If the number of the inputs is not 3.
|
||||
"""
|
||||
if len(inputs) != 3:
|
||||
raise ValueError('MeanSurfaceDistance need 3 inputs (y_pred, y, label), but got {}.'.format(len(inputs)))
|
||||
y_pred = self._convert_data(inputs[0])
|
||||
y = self._convert_data(inputs[1])
|
||||
label_idx = inputs[2]
|
||||
|
||||
if y_pred.size == 0 or y_pred.shape != y.shape:
|
||||
raise ValueError("y_pred and y should have same shape, but got {}, {}.".format(y_pred.shape, y.shape))
|
||||
|
||||
if y_pred.dtype != bool:
|
||||
y_pred = y_pred == label_idx
|
||||
if y.dtype != bool:
|
||||
y = y == label_idx
|
||||
|
||||
self._y_pred_edges = morphology.binary_erosion(y_pred) ^ y_pred
|
||||
self._y_edges = morphology.binary_erosion(y) ^ y
|
||||
self._is_update = True
|
||||
|
||||
def eval(self):
|
||||
"""
|
||||
Calculate residual mean square surface distance.
|
||||
"""
|
||||
if self._is_update is False:
|
||||
raise RuntimeError('Call the update method before calling eval.')
|
||||
|
||||
residual_mean_square_distance = self._get_surface_distance(self._y_pred_edges, self._y_edges)
|
||||
|
||||
if residual_mean_square_distance.shape == (0,):
|
||||
return np.inf
|
||||
|
||||
rms_surface_distance = (residual_mean_square_distance**2).mean()
|
||||
|
||||
if not self.symmetric:
|
||||
return rms_surface_distance
|
||||
|
||||
contrary_residual_mean_square_distance = self._get_surface_distance(self._y_edges, self._y_pred_edges)
|
||||
if contrary_residual_mean_square_distance.shape == (0,):
|
||||
return np.inf
|
||||
|
||||
contrary_rms_surface_distance = (contrary_residual_mean_square_distance**2).mean()
|
||||
|
||||
rms_distance = np.sqrt(np.mean((rms_surface_distance, contrary_rms_surface_distance)))
|
||||
return rms_distance
|
|
@ -0,0 +1,70 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
# """test_mean_surface_distance"""
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import pytest
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn.metrics import get_metric_fn, MeanSurfaceDistance
|
||||
|
||||
|
||||
def test_mean_surface_distance():
|
||||
"""test_mean_surface_distance"""
|
||||
x = Tensor(np.array([[3, 0, 1], [1, 3, 0], [1, 0, 2]]))
|
||||
y = Tensor(np.array([[0, 2, 1], [1, 2, 1], [0, 0, 1]]))
|
||||
metric = get_metric_fn('mean_surface_distance')
|
||||
metric.clear()
|
||||
metric.update(x, y, 0)
|
||||
distance = metric.eval()
|
||||
|
||||
assert math.isclose(distance, 0.8047378541243649, abs_tol=0.001)
|
||||
|
||||
|
||||
def test_mean_surface_distance_update1():
|
||||
x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]]))
|
||||
metric = MeanSurfaceDistance()
|
||||
metric.clear()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
metric.update(x)
|
||||
|
||||
|
||||
def test_mean_surface_distance_update2():
|
||||
x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]]))
|
||||
y = Tensor(np.array([1, 0]))
|
||||
metric = MeanSurfaceDistance()
|
||||
metric.clear()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
metric.update(x, y)
|
||||
|
||||
|
||||
def test_mean_surface_distance_init():
|
||||
with pytest.raises(ValueError):
|
||||
MeanSurfaceDistance(symmetric=False, distance_metric="eucli")
|
||||
|
||||
|
||||
def test_mean_surface_distance_init2():
|
||||
with pytest.raises(TypeError):
|
||||
MeanSurfaceDistance(symmetric=1)
|
||||
|
||||
|
||||
def test_mean_surface_distance_runtime():
|
||||
metric = MeanSurfaceDistance()
|
||||
metric.clear()
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
metric.eval()
|
|
@ -0,0 +1,70 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
# """test_mean_surface_distance"""
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import pytest
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn.metrics import get_metric_fn, RootMeanSquareDistance
|
||||
|
||||
|
||||
def test_root_mean_square_distance():
|
||||
"""test_root_mean_square_distance"""
|
||||
x = Tensor(np.array([[3, 0, 1], [1, 3, 0], [1, 0, 2]]))
|
||||
y = Tensor(np.array([[0, 2, 1], [1, 2, 1], [0, 0, 1]]))
|
||||
metric = get_metric_fn('root_mean_square_distance')
|
||||
metric.clear()
|
||||
metric.update(x, y, 0)
|
||||
distance = metric.eval()
|
||||
|
||||
assert math.isclose(distance, 1.0000000000000002, abs_tol=0.001)
|
||||
|
||||
|
||||
def test_root_mean_square_distance_update1():
|
||||
x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]]))
|
||||
metric = RootMeanSquareDistance()
|
||||
metric.clear()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
metric.update(x)
|
||||
|
||||
|
||||
def test_root_mean_square_distance_update2():
|
||||
x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]]))
|
||||
y = Tensor(np.array([1, 0]))
|
||||
metric = RootMeanSquareDistance()
|
||||
metric.clear()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
metric.update(x, y)
|
||||
|
||||
|
||||
def test_root_mean_square_distance_init():
|
||||
with pytest.raises(ValueError):
|
||||
RootMeanSquareDistance(symmetric=False, distance_metric="eucli")
|
||||
|
||||
|
||||
def test_root_mean_square_distance_init2():
|
||||
with pytest.raises(TypeError):
|
||||
RootMeanSquareDistance(symmetric=1)
|
||||
|
||||
|
||||
def test_root_mean_square_distance_runtime():
|
||||
metric = RootMeanSquareDistance()
|
||||
metric.clear()
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
metric.eval()
|
Loading…
Reference in New Issue