forked from mindspore-Ecosystem/mindspore
hausdorff_distance
This commit is contained in:
parent
bfc6cca4f8
commit
9a5b9525de
|
@ -19,6 +19,7 @@ Functions to measure the performance of the machine learning models
|
||||||
on the evaluation dataset. It's used to choose the best model.
|
on the evaluation dataset. It's used to choose the best model.
|
||||||
"""
|
"""
|
||||||
from .accuracy import Accuracy
|
from .accuracy import Accuracy
|
||||||
|
from .hausdorff_distance import HausdorffDistance
|
||||||
from .error import MAE, MSE
|
from .error import MAE, MSE
|
||||||
from .metric import Metric
|
from .metric import Metric
|
||||||
from .precision import Precision
|
from .precision import Precision
|
||||||
|
@ -33,6 +34,7 @@ __all__ = [
|
||||||
"MAE", "MSE",
|
"MAE", "MSE",
|
||||||
"Metric",
|
"Metric",
|
||||||
"Precision",
|
"Precision",
|
||||||
|
"HausdorffDistance",
|
||||||
"Recall",
|
"Recall",
|
||||||
"Fbeta",
|
"Fbeta",
|
||||||
"F1",
|
"F1",
|
||||||
|
@ -49,6 +51,7 @@ __factory__ = {
|
||||||
'recall': Recall,
|
'recall': Recall,
|
||||||
'F1': F1,
|
'F1': F1,
|
||||||
'topk': TopKCategoricalAccuracy,
|
'topk': TopKCategoricalAccuracy,
|
||||||
|
'hausdorff_distance': HausdorffDistance,
|
||||||
'top_1_accuracy': Top1CategoricalAccuracy,
|
'top_1_accuracy': Top1CategoricalAccuracy,
|
||||||
'top_5_accuracy': Top5CategoricalAccuracy,
|
'top_5_accuracy': Top5CategoricalAccuracy,
|
||||||
'mae': MAE,
|
'mae': MAE,
|
||||||
|
|
|
@ -0,0 +1,265 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""HausdorffDistance."""
|
||||||
|
|
||||||
|
from collections import abc
|
||||||
|
from abc import ABCMeta
|
||||||
|
from scipy.ndimage import morphology
|
||||||
|
import numpy as np
|
||||||
|
from mindspore.common.tensor import Tensor
|
||||||
|
from mindspore._checkparam import Validator as validator
|
||||||
|
from .metric import Metric
|
||||||
|
|
||||||
|
|
||||||
|
class _ROISpatialData(metaclass=ABCMeta):
|
||||||
|
"""
|
||||||
|
Produce Region Of Interest (ROI). Support to crop ND spatial data. The center and size of the space should be
|
||||||
|
provided, if not, the start and end coordinates of the ROI must be provided.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
roi_center (int): The central coordinates of the crop ROI.
|
||||||
|
roi_size (int): The size of the crop ROI.
|
||||||
|
roi_start (int): The start coordinates of the crop ROI.
|
||||||
|
roi_end (int): The end coordinates of the crop ROI.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, roi_center=None, roi_size=None, roi_start=None, roi_end=None):
|
||||||
|
|
||||||
|
if roi_center is not None and roi_size is not None:
|
||||||
|
roi_center = np.asarray(roi_center, dtype=np.int16)
|
||||||
|
roi_size = np.asarray(roi_size, dtype=np.int16)
|
||||||
|
self.roi_start = np.maximum(roi_center - np.floor_divide(roi_size, 2), 0)
|
||||||
|
self.roi_end = np.maximum(self.roi_start + roi_size, self.roi_start)
|
||||||
|
else:
|
||||||
|
if roi_start is None or roi_end is None:
|
||||||
|
raise ValueError("Please provide the center coordinates, size or start coordinates and end coordinates"
|
||||||
|
" of ROI.")
|
||||||
|
self.roi_start = np.maximum(np.asarray(roi_start, dtype=np.int16), 0)
|
||||||
|
self.roi_end = np.maximum(np.asarray(roi_end, dtype=np.int16), self.roi_start)
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
"""
|
||||||
|
Transform the data, if the data is channel first, slicing is not applicable to channel dim.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data (np.ndarray): Data to be converted.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray, transform result.
|
||||||
|
"""
|
||||||
|
sd = min(len(self.roi_start), len(self.roi_end), len(data.shape[1:]))
|
||||||
|
slices = [slice(None)] + [slice(s, e) for s, e in zip(self.roi_start[:sd], self.roi_end[:sd])]
|
||||||
|
return data[tuple(slices)]
|
||||||
|
|
||||||
|
|
||||||
|
class HausdorffDistance(Metric):
|
||||||
|
r"""
|
||||||
|
Calculate the Hausdorff distance. Hausdorff distance is the maximum and minimum distance between two point sets.
|
||||||
|
Given two feature sets A and B, the Hausdorff distance between two point sets A and B is defined as follows:
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
\text{H}(A, B) = \text{max}[\text{h}(A, B), \text{h}(B, A)]
|
||||||
|
\text{h}(A, B) = \underset{a \in A}{\text{max}}\{\underset{b \in B}{\text{min}} \rVert a - b \rVert \}
|
||||||
|
\text{h}(A, B) = \underset{b \in B}{\text{max}}\{\underset{a \in A}{\text{min}} \rVert b - a \rVert \}
|
||||||
|
|
||||||
|
Args:
|
||||||
|
distance_metric (string): The parameter of calculating Hausdorff distance supports three measurement methods,
|
||||||
|
"euclidean", "chessboard" or "taxicab". Default: "euclidean".
|
||||||
|
percentile (float): Floating point numbers between 0 and 100. Specify the percentile parameter to get the
|
||||||
|
percentile of the Hausdorff distance. Defaults: None.
|
||||||
|
directed (bool): It can be divided into directional and non directional Hausdorff distance,
|
||||||
|
and the default is non directional Hausdorff distance, specify the percentile parameter to get
|
||||||
|
the percentile of the Hausdorff distance. Default: False.
|
||||||
|
crop (bool): Crop input images and only keep the foregrounds. In order to maintain two inputs' shapes,
|
||||||
|
here the bounding box is achieved by (y_pred | y) which represents the union set of two images.
|
||||||
|
Default: True.
|
||||||
|
"""
|
||||||
|
def __init__(self, distance_metric="euclidean", percentile=None, directed=False, crop=True):
|
||||||
|
super(HausdorffDistance, self).__init__()
|
||||||
|
string_list = ["euclidean", "chessboard", "taxicab"]
|
||||||
|
distance_metric = validator.check_value_type("distance_metric", distance_metric, [str])
|
||||||
|
self.distance_metric = validator.check_string(distance_metric, string_list, "distance_metric")
|
||||||
|
self.percentile = percentile if percentile is None else validator.check_value_type("percentile",
|
||||||
|
percentile, [float])
|
||||||
|
self.directed = directed if directed is None else validator.check_value_type("directed", directed, [bool])
|
||||||
|
self.crop = crop if crop is None else validator.check_value_type("crop", crop, [bool])
|
||||||
|
self.clear()
|
||||||
|
|
||||||
|
def _is_tuple_rep(self, tup, dim):
|
||||||
|
"""
|
||||||
|
Returns the tup containing the dim value by shortening or repeating the input.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: When tup is a sequence and tup length is not dim.
|
||||||
|
|
||||||
|
"""
|
||||||
|
result = None
|
||||||
|
if not self._is_iterable_sequence(tup):
|
||||||
|
result = (tup,) * dim
|
||||||
|
elif len(tup) == dim:
|
||||||
|
result = tuple(tup)
|
||||||
|
|
||||||
|
if result is None:
|
||||||
|
raise ValueError(f"Sequence must have length {dim}, but got {len(tup)}.")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _is_tuple(self, inputs):
|
||||||
|
"""
|
||||||
|
Returns a tuple of inputs.
|
||||||
|
"""
|
||||||
|
if not self._is_iterable_sequence(inputs):
|
||||||
|
inputs = (inputs,)
|
||||||
|
|
||||||
|
return tuple(inputs)
|
||||||
|
|
||||||
|
def _is_iterable_sequence(self, inputs):
|
||||||
|
"""
|
||||||
|
Determine if the input is an iterable sequence and it is not a string.
|
||||||
|
"""
|
||||||
|
if isinstance(inputs, Tensor):
|
||||||
|
return int(inputs.dim()) > 0
|
||||||
|
return isinstance(inputs, abc.Iterable) and not isinstance(inputs, str)
|
||||||
|
|
||||||
|
def _create_space_bounding_box(self, image, func=lambda x: x > 0, channel_indices=None, margin=0):
|
||||||
|
"""
|
||||||
|
The position of the space bounding box that generates the foreground in an image with start end.
|
||||||
|
The user can define any function to select the desired foreground from the whole image or the specified channel.
|
||||||
|
It can also add margins to each size of the bounding box.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: source image to generate bounding box from.
|
||||||
|
func: function to select expected foreground, default is to select values > 0.
|
||||||
|
channel_indices: if defined, select foreground only on the specified channels
|
||||||
|
of image. if None, select foreground on the whole image.
|
||||||
|
margin: add margin value to spatial dims of the bounding box, if only a single value is provided,
|
||||||
|
use it for all dims.
|
||||||
|
"""
|
||||||
|
data = image[[*(self._is_tuple(channel_indices))]] if channel_indices is not None else image
|
||||||
|
data = np.any(func(data), axis=0)
|
||||||
|
nonzero_idx = np.nonzero(data)
|
||||||
|
margin = self._is_tuple_rep(margin, data.ndim)
|
||||||
|
|
||||||
|
box_start = list()
|
||||||
|
box_end = list()
|
||||||
|
for i in range(data.ndim):
|
||||||
|
if nonzero_idx[i].size <= 0:
|
||||||
|
raise ValueError("did not find nonzero index at the spatial dim {}".format(i))
|
||||||
|
box_start.append(max(0, np.min(nonzero_idx[i]) - margin[i]))
|
||||||
|
box_end.append(min(data.shape[i], np.max(nonzero_idx[i]) + margin[i] + 1))
|
||||||
|
return box_start, box_end
|
||||||
|
|
||||||
|
def _calculate_percent_hausdorff_distance(self, y_pred_edges, y_edges):
|
||||||
|
"""
|
||||||
|
Calculate the directed Hausdorff distance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y_pred_edges (np.ndarray): the edge of the predictions.
|
||||||
|
y_edges (np.ndarray): the edge of the ground truth.
|
||||||
|
"""
|
||||||
|
surface_distance = self._get_surface_distance(y_pred_edges, y_edges)
|
||||||
|
|
||||||
|
if surface_distance.shape == (0,):
|
||||||
|
return np.inf
|
||||||
|
|
||||||
|
if not self.percentile:
|
||||||
|
return surface_distance.max()
|
||||||
|
if 0 <= self.percentile <= 100:
|
||||||
|
return np.percentile(surface_distance, self.percentile)
|
||||||
|
|
||||||
|
raise ValueError(f"percentile should be a value between 0 and 100, get {self.percentile}.")
|
||||||
|
|
||||||
|
def _get_surface_distance(self, y_pred_edges, y_edges):
|
||||||
|
"""
|
||||||
|
Calculate the surface distances from `y_pred_edges` to `y_edges`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y_pred_edges (np.ndarray): the edge of the predictions.
|
||||||
|
y_edges (np.ndarray): the edge of the ground truth.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not np.any(y_pred_edges):
|
||||||
|
return np.array([])
|
||||||
|
|
||||||
|
if not np.any(y_edges):
|
||||||
|
dis = np.inf * np.ones_like(y_edges)
|
||||||
|
else:
|
||||||
|
if self.distance_metric == "euclidean":
|
||||||
|
dis = morphology.distance_transform_edt(~y_edges)
|
||||||
|
elif self.distance_metric == "chessboard" or self.distance_metric == "taxicab":
|
||||||
|
dis = morphology.distance_transform_cdt(~y_edges, metric=self.distance_metric)
|
||||||
|
|
||||||
|
surface_distance = dis[y_pred_edges]
|
||||||
|
|
||||||
|
return surface_distance
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
"""Clears the internal evaluation result."""
|
||||||
|
self.y_pred_edges = 0
|
||||||
|
self.y_edges = 0
|
||||||
|
self._is_update = False
|
||||||
|
|
||||||
|
def update(self, *inputs):
|
||||||
|
"""
|
||||||
|
Updates the internal evaluation result 'y_pred', 'y' and 'label_idx'.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: Input 'y_pred', 'y' and 'label_idx'. 'y_pred' and 'y' are Tensor or numpy.ndarray. 'y_pred' is the
|
||||||
|
predicted binary image. 'y' is the actual binary image. 'label_idx', the data type of `label_idx`
|
||||||
|
is int.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the number of the inputs is not 3.
|
||||||
|
"""
|
||||||
|
if len(inputs) != 3:
|
||||||
|
raise ValueError('HausdorffDistance need 3 inputs (y_pred, y, label), but got {}'.format(len(inputs)))
|
||||||
|
y_pred = self._convert_data(inputs[0])
|
||||||
|
y = self._convert_data(inputs[1])
|
||||||
|
label_idx = inputs[2]
|
||||||
|
|
||||||
|
if y_pred.size == 0 or y_pred.shape != y.shape:
|
||||||
|
raise ValueError("Labelfields should have the same shape, but got {}, {}".format(y_pred.shape, y.shape))
|
||||||
|
|
||||||
|
y_pred = (y_pred == label_idx) if y_pred.dtype is not bool else y_pred
|
||||||
|
y = (y == label_idx) if y.dtype is not bool else y
|
||||||
|
|
||||||
|
res1, res2 = None, None
|
||||||
|
if self.crop:
|
||||||
|
if not np.any(y_pred | y):
|
||||||
|
res1 = np.zeros_like(y_pred)
|
||||||
|
res2 = np.zeros_like(y)
|
||||||
|
|
||||||
|
y_pred, y = np.expand_dims(y_pred, 0), np.expand_dims(y, 0)
|
||||||
|
box_start, box_end = self._create_space_bounding_box(y_pred | y)
|
||||||
|
cropper = _ROISpatialData(roi_start=box_start, roi_end=box_end)
|
||||||
|
y_pred, y = np.squeeze(cropper(y_pred)), np.squeeze(cropper(y))
|
||||||
|
|
||||||
|
self.y_pred_edges = morphology.binary_erosion(y_pred) ^ y_pred if res1 is None else res1
|
||||||
|
self.y_edges = morphology.binary_erosion(y) ^ y if res2 is None else res2
|
||||||
|
self._is_update = True
|
||||||
|
|
||||||
|
def eval(self):
|
||||||
|
"""
|
||||||
|
Calculate the no-directed or directed Hausdorff distance.
|
||||||
|
"""
|
||||||
|
if self._is_update is False:
|
||||||
|
raise RuntimeError('Call the update method before calling eval.')
|
||||||
|
|
||||||
|
hd = self._calculate_percent_hausdorff_distance(self.y_pred_edges, self.y_edges)
|
||||||
|
if self.directed:
|
||||||
|
return hd
|
||||||
|
|
||||||
|
hd2 = self._calculate_percent_hausdorff_distance(self.y_edges, self.y_pred_edges)
|
||||||
|
return max(hd, hd2)
|
|
@ -0,0 +1,65 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
# """test_hausdorff_distance"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.nn.metrics import get_metric_fn, HausdorffDistance
|
||||||
|
|
||||||
|
|
||||||
|
def test_hausdorff_distance():
|
||||||
|
"""test_hausdorff_distance"""
|
||||||
|
x = Tensor(np.array([[3, 0, 1], [1, 3, 0], [1, 0, 2]]))
|
||||||
|
y = Tensor(np.array([[0, 2, 1], [1, 2, 1], [0, 0, 1]]))
|
||||||
|
metric = get_metric_fn('hausdorff_distance')
|
||||||
|
metric.clear()
|
||||||
|
metric.update(x, y, 0)
|
||||||
|
distance = metric.eval()
|
||||||
|
|
||||||
|
assert math.isclose(distance, 1.4142135623730951, abs_tol=0.001)
|
||||||
|
|
||||||
|
|
||||||
|
def test_hausdorff_distance_update1():
|
||||||
|
x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]]))
|
||||||
|
metric = HausdorffDistance()
|
||||||
|
metric.clear()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
metric.update(x)
|
||||||
|
|
||||||
|
|
||||||
|
def test_hausdorff_distance_update2():
|
||||||
|
x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]]))
|
||||||
|
y = Tensor(np.array([1, 0]))
|
||||||
|
metric = HausdorffDistance()
|
||||||
|
metric.clear()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
metric.update(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
def test_hausdorff_distance_init():
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
HausdorffDistance(distance_metric="eucli", percentile=None, directed=False, crop=False)
|
||||||
|
|
||||||
|
|
||||||
|
def test_hausdorff_distance_runtime():
|
||||||
|
metric = HausdorffDistance()
|
||||||
|
metric.clear()
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
metric.eval()
|
Loading…
Reference in New Issue