forked from mindspore-Ecosystem/mindspore
commit
e341b7ee97
|
@ -32,6 +32,9 @@ from .topk import TopKCategoricalAccuracy, Top1CategoricalAccuracy, Top5Categori
|
||||||
from .loss import Loss
|
from .loss import Loss
|
||||||
from .mean_surface_distance import MeanSurfaceDistance
|
from .mean_surface_distance import MeanSurfaceDistance
|
||||||
from .root_mean_square_surface_distance import RootMeanSquareDistance
|
from .root_mean_square_surface_distance import RootMeanSquareDistance
|
||||||
|
from .bleu_score import BleuScore
|
||||||
|
from .cosine_similarity import CosineSimilarity
|
||||||
|
from .occlusion_sensitivity import OcclusionSensitivity
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"names",
|
"names",
|
||||||
|
@ -43,6 +46,9 @@ __all__ = [
|
||||||
"HausdorffDistance",
|
"HausdorffDistance",
|
||||||
"Recall",
|
"Recall",
|
||||||
"Fbeta",
|
"Fbeta",
|
||||||
|
"BleuScore",
|
||||||
|
"CosineSimilarity",
|
||||||
|
"OcclusionSensitivity",
|
||||||
"F1",
|
"F1",
|
||||||
"Dice",
|
"Dice",
|
||||||
"ROC",
|
"ROC",
|
||||||
|
@ -64,6 +70,9 @@ __factory__ = {
|
||||||
'dice': Dice,
|
'dice': Dice,
|
||||||
'roc': ROC,
|
'roc': ROC,
|
||||||
'auc': auc,
|
'auc': auc,
|
||||||
|
'bleu_score': BleuScore,
|
||||||
|
'cosine_similarity': CosineSimilarity,
|
||||||
|
'occlusion_sensitivity': OcclusionSensitivity,
|
||||||
'topk': TopKCategoricalAccuracy,
|
'topk': TopKCategoricalAccuracy,
|
||||||
'hausdorff_distance': HausdorffDistance,
|
'hausdorff_distance': HausdorffDistance,
|
||||||
'top_1_accuracy': Top1CategoricalAccuracy,
|
'top_1_accuracy': Top1CategoricalAccuracy,
|
||||||
|
|
|
@ -1,103 +0,0 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
"""Evaluation."""
|
|
||||||
import numpy as np
|
|
||||||
from .metric import Metric
|
|
||||||
|
|
||||||
_eval_types = {'classification', 'multilabel'}
|
|
||||||
|
|
||||||
|
|
||||||
class EvaluationBase(Metric):
|
|
||||||
"""
|
|
||||||
Base class of evaluation.
|
|
||||||
|
|
||||||
Note:
|
|
||||||
Please refer to the definition of class `Accuracy`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
eval_type (str): Type of evaluation must be in {'classification', 'multilabel'}.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
TypeError: If the input type is not classification or multilabel.
|
|
||||||
"""
|
|
||||||
def __init__(self, eval_type):
|
|
||||||
super(EvaluationBase, self).__init__()
|
|
||||||
if eval_type not in _eval_types:
|
|
||||||
raise TypeError('Type must be in {}, but got {}'.format(_eval_types, eval_type))
|
|
||||||
self._type = eval_type
|
|
||||||
|
|
||||||
def _check_shape(self, y_pred, y):
|
|
||||||
"""
|
|
||||||
Checks the shapes of y_pred and y.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
y_pred (Tensor): Predict array.
|
|
||||||
y (Tensor): Target array.
|
|
||||||
"""
|
|
||||||
if self._type == 'classification':
|
|
||||||
if y_pred.ndim != y.ndim + 1:
|
|
||||||
raise ValueError('Classification case, dims of y_pred equal dims of y add 1, '
|
|
||||||
'but got y_pred: {} dims and y: {} dims'.format(y_pred.ndim, y.ndim))
|
|
||||||
if y.shape != (y_pred.shape[0],) + y_pred.shape[2:]:
|
|
||||||
raise ValueError('Classification case, y_pred shape and y shape can not match. '
|
|
||||||
'got y_pred shape is {} and y shape is {}'.format(y_pred.shape, y.shape))
|
|
||||||
else:
|
|
||||||
if y_pred.ndim != y.ndim:
|
|
||||||
raise ValueError('{} case, dims of y_pred need equal with dims of y, but got y_pred: {} '
|
|
||||||
'dims and y: {} dims.'.format(self._type, y_pred.ndim, y.ndim))
|
|
||||||
if y_pred.shape != y.shape:
|
|
||||||
raise ValueError('{} case, y_pred shape need equal with y shape, but got y_pred: {} and y: {}'.
|
|
||||||
format(self._type, y_pred.shape, y.shape))
|
|
||||||
|
|
||||||
def _check_value(self, y_pred, y):
|
|
||||||
"""
|
|
||||||
Checks the values of y_pred and y.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
y_pred (Tensor): Predict array.
|
|
||||||
y (Tensor): Target array.
|
|
||||||
"""
|
|
||||||
if self._type != 'classification' and not (np.equal(y_pred ** 2, y_pred).all() and np.equal(y ** 2, y).all()):
|
|
||||||
raise ValueError('For multilabel case, input value must be 1 or 0.')
|
|
||||||
|
|
||||||
def clear(self):
|
|
||||||
"""
|
|
||||||
A interface describes the behavior of clearing the internal evaluation result.
|
|
||||||
|
|
||||||
Note:
|
|
||||||
All subclasses must override this interface.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def update(self, *inputs):
|
|
||||||
"""
|
|
||||||
A interface describes the behavior of updating the internal evaluation result.
|
|
||||||
|
|
||||||
Note:
|
|
||||||
All subclasses must override this interface.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
inputs: The first item is predicted array and the second item is target array.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def eval(self):
|
|
||||||
"""
|
|
||||||
A interface describes the behavior of computing the evaluation result.
|
|
||||||
|
|
||||||
Note:
|
|
||||||
All subclasses must override this interface.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
|
@ -14,7 +14,7 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""Accuracy."""
|
"""Accuracy."""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from ._evaluation import EvaluationBase
|
from .metric import EvaluationBase
|
||||||
|
|
||||||
|
|
||||||
class Accuracy(EvaluationBase):
|
class Accuracy(EvaluationBase):
|
||||||
|
|
|
@ -0,0 +1,149 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""BleuScore."""
|
||||||
|
from collections import Counter
|
||||||
|
import numpy as np
|
||||||
|
from mindspore._checkparam import Validator as validator
|
||||||
|
from .metric import Metric
|
||||||
|
|
||||||
|
|
||||||
|
class BleuScore(Metric):
|
||||||
|
"""
|
||||||
|
Calculate BLEU score of machine translated text with one or more references.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_gram (int): The n_gram value ranged from 1 to 4. Default: 4
|
||||||
|
smooth (bool): Whether or not to apply smoothing. Default: False
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> candidate_corpus = [['i', 'have', 'a', 'pen', 'on', 'my', 'desk']]
|
||||||
|
>>> reference_corpus = [[['i', 'have', 'a', 'pen', 'in', 'my', 'desk'],
|
||||||
|
>>> ['there', 'is', 'a', 'pen', 'on', 'the', 'desk']]]
|
||||||
|
>>> metric = BleuScore()
|
||||||
|
>>> metric.clear()
|
||||||
|
>>> metric.update(candidate_corpus, reference_corpus)
|
||||||
|
>>> bleu_score = metric.eval()
|
||||||
|
0.5946035575013605
|
||||||
|
"""
|
||||||
|
def __init__(self, n_gram=4, smooth=False):
|
||||||
|
super().__init__()
|
||||||
|
self.n_gram = validator.check_value_type("n_gram", n_gram, [int])
|
||||||
|
if self.n_gram > 4 or self.n_gram < 1:
|
||||||
|
raise ValueError('The n_gram value ranged from 1 to 4, but got {}'.format(n_gram))
|
||||||
|
|
||||||
|
self.smooth = validator.check_value_type("smooth", smooth, [bool])
|
||||||
|
self.clear()
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
"""Clear the internal evaluation result."""
|
||||||
|
self._numerator = np.zeros(self.n_gram)
|
||||||
|
self._denominator = np.zeros(self.n_gram)
|
||||||
|
self._precision_scores = np.zeros(self.n_gram)
|
||||||
|
self._c = 0.0
|
||||||
|
self._r = 0.0
|
||||||
|
self._trans_len = 0
|
||||||
|
self._ref_len = 0
|
||||||
|
self._is_update = False
|
||||||
|
|
||||||
|
def _count_ngram(self, ngram_input_list, n_gram):
|
||||||
|
"""
|
||||||
|
Counting how many times each word appears in a given text with ngram.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ngram_input_list (list): A list of translated text or reference texts.
|
||||||
|
n_gram (int): gram value ranged 1 to 4.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
ngram_counter: a collections.Counter object of ngram.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ngram_counter = Counter()
|
||||||
|
|
||||||
|
for i in range(1, n_gram + 1):
|
||||||
|
for j in range(len(ngram_input_list) - i + 1):
|
||||||
|
ngram_key = tuple(ngram_input_list[j:(i + j)])
|
||||||
|
ngram_counter[ngram_key] += 1
|
||||||
|
|
||||||
|
return ngram_counter
|
||||||
|
|
||||||
|
def update(self, *inputs):
|
||||||
|
"""
|
||||||
|
Updates the internal evaluation result with `candidate_corpus` and `reference_corpus`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: Input `candidate_corpus` and ``reference_corpus`. `candidate_corpus` and `reference_corpus` are a
|
||||||
|
list. The `candidate_corpus` is an iterable of machine translated corpus. The `reference_corpus` is
|
||||||
|
an iterable of iterables of reference corpus.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the number of input is not 2.
|
||||||
|
"""
|
||||||
|
if len(inputs) != 2:
|
||||||
|
raise ValueError('The bleu_score need 2 inputs (candidate_corpus, reference_corpus), '
|
||||||
|
'but got {}'.format(len(inputs)))
|
||||||
|
candidate_corpus = inputs[0]
|
||||||
|
reference_corpus = inputs[1]
|
||||||
|
if len(candidate_corpus) != len(reference_corpus):
|
||||||
|
raise ValueError('translate_corpus and reference_corpus should be equal in length, '
|
||||||
|
'but got {} {}'.format(len(candidate_corpus), len(reference_corpus)))
|
||||||
|
|
||||||
|
for (candidate, references) in zip(candidate_corpus, reference_corpus):
|
||||||
|
self._c += len(candidate)
|
||||||
|
ref_len_list = [len(ref) for ref in references]
|
||||||
|
ref_len_diff = [abs(len(candidate) - x) for x in ref_len_list]
|
||||||
|
self._r += ref_len_list[ref_len_diff.index(min(ref_len_diff))]
|
||||||
|
translation_counter = self._count_ngram(candidate, self.n_gram)
|
||||||
|
reference_counter = Counter()
|
||||||
|
|
||||||
|
for ref in references:
|
||||||
|
reference_counter |= self._count_ngram(ref, self.n_gram)
|
||||||
|
|
||||||
|
ngram_counter_clip = translation_counter & reference_counter
|
||||||
|
|
||||||
|
for counter_clip in ngram_counter_clip:
|
||||||
|
self._numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip]
|
||||||
|
|
||||||
|
for counter in translation_counter:
|
||||||
|
self._denominator[len(counter) - 1] += translation_counter[counter]
|
||||||
|
|
||||||
|
self._trans_len = np.array(self._c)
|
||||||
|
self._ref_len = np.array(self._r)
|
||||||
|
self._is_update = True
|
||||||
|
|
||||||
|
def eval(self):
|
||||||
|
"""
|
||||||
|
Computes the bleu score.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A numpy with bleu score.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if self._is_update is False:
|
||||||
|
raise RuntimeError('Call the update method before calling eval.')
|
||||||
|
if min(self._numerator) == 0.0:
|
||||||
|
return np.array(0.0)
|
||||||
|
|
||||||
|
if self.smooth:
|
||||||
|
precision_scores = np.add(self._numerator, np.ones(self.n_gram)) / np.add(self._denominator,
|
||||||
|
np.ones(self.n_gram))
|
||||||
|
else:
|
||||||
|
precision_scores = self._numerator / self._denominator
|
||||||
|
|
||||||
|
log_precision_scores = np.array([1.0 / self.n_gram] * self.n_gram) * np.log(precision_scores)
|
||||||
|
geometric_mean = np.exp(np.sum(log_precision_scores))
|
||||||
|
brevity_penalty = np.array(1.0) if self._c > self._r else np.exp(1 - (self._ref_len / self._trans_len))
|
||||||
|
bleu = brevity_penalty * geometric_mean
|
||||||
|
|
||||||
|
return bleu
|
|
@ -0,0 +1,96 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""CosineSimilarity."""
|
||||||
|
import numpy as np
|
||||||
|
from mindspore._checkparam import Validator as validator
|
||||||
|
from .metric import Metric
|
||||||
|
|
||||||
|
|
||||||
|
class CosineSimilarity(Metric):
|
||||||
|
"""
|
||||||
|
Computes representation similarity
|
||||||
|
|
||||||
|
Args:
|
||||||
|
similarity (str): 'dot' or 'cosine'. Default: 'cosine'
|
||||||
|
reduction (str): 'none', 'sum', 'mean' (all along dim -1). Default: 'none'
|
||||||
|
zero_diagonal (bool): if True, the diagonals are set to zero. Default: True
|
||||||
|
|
||||||
|
Return:
|
||||||
|
A square matrix (input1, input1) with the similarity scores between all elements.
|
||||||
|
If sum or mean are used, then returns (b, 1) with the reduced value for each row
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> test_data = np.random.randn(4, 8)
|
||||||
|
>>> metric = CosineSimilarity()
|
||||||
|
>>> metric.clear()
|
||||||
|
>>> metric.update(test_data)
|
||||||
|
>>> square_matrix = metric.eval()
|
||||||
|
[[0. -0.14682831 0.19102288 -0.36204537]
|
||||||
|
...
|
||||||
|
]
|
||||||
|
"""
|
||||||
|
def __init__(self, similarity='cosine', reduction='none', zero_diagonal=True):
|
||||||
|
super().__init__()
|
||||||
|
similarity_list = ['dot', 'cosine']
|
||||||
|
reduction_list = ['none', 'sum', 'mean']
|
||||||
|
similarity = validator.check_value_type("similarity", similarity, [str])
|
||||||
|
self.similarity = validator.check_string(similarity, similarity_list, "similarity")
|
||||||
|
reduction = validator.check_value_type("reduction", reduction, [str])
|
||||||
|
self.reduction = validator.check_string(reduction, reduction_list, "reduction")
|
||||||
|
self.zero_diagonal = validator.check_value_type("zero_diagonal", zero_diagonal, [bool])
|
||||||
|
self.clear()
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
"""Clears the internal evaluation result."""
|
||||||
|
self.sqr_mtx_res = 0
|
||||||
|
self._is_update = False
|
||||||
|
|
||||||
|
def update(self, *inputs):
|
||||||
|
"""
|
||||||
|
Updates the internal evaluation result with 'input1'.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: input_data `input1`. The input_data is a `Tensor`or an array.
|
||||||
|
"""
|
||||||
|
input_data = self._convert_data(inputs[0])
|
||||||
|
|
||||||
|
if self.similarity == 'cosine':
|
||||||
|
data = np.linalg.norm(input_data, ord=2, axis=1)
|
||||||
|
input_data = input_data / np.expand_dims(data, 1)
|
||||||
|
|
||||||
|
self.sqr_mtx_res = np.dot(input_data, input_data.transpose(1, 0))
|
||||||
|
self._is_update = True
|
||||||
|
|
||||||
|
def eval(self):
|
||||||
|
"""
|
||||||
|
Computes the Cosine_Similarity square matrix.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A square matrix.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not self._is_update:
|
||||||
|
raise RuntimeError('Call the update method before calling eval.')
|
||||||
|
|
||||||
|
if self.zero_diagonal:
|
||||||
|
np.fill_diagonal(self.sqr_mtx_res, 0)
|
||||||
|
|
||||||
|
if self.reduction == 'mean':
|
||||||
|
self.sqr_mtx_res = np.mean(self.sqr_mtx_res, axis=-1)
|
||||||
|
|
||||||
|
if self.reduction == 'sum':
|
||||||
|
self.sqr_mtx_res = np.sum(self.sqr_mtx_res, axis=-1)
|
||||||
|
|
||||||
|
return self.sqr_mtx_res
|
|
@ -17,6 +17,8 @@ from abc import ABCMeta, abstractmethod
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
|
|
||||||
|
_eval_types = {'classification', 'multilabel'}
|
||||||
|
|
||||||
|
|
||||||
class Metric(metaclass=ABCMeta):
|
class Metric(metaclass=ABCMeta):
|
||||||
"""
|
"""
|
||||||
|
@ -140,3 +142,87 @@ class Metric(metaclass=ABCMeta):
|
||||||
inputs: A variable-length input argument list.
|
inputs: A variable-length input argument list.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError('Must define update function to use this base class')
|
raise NotImplementedError('Must define update function to use this base class')
|
||||||
|
|
||||||
|
|
||||||
|
class EvaluationBase(Metric):
|
||||||
|
"""
|
||||||
|
Base class of evaluation.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Please refer to the definition of class `Accuracy`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
eval_type (str): Type of evaluation must be in {'classification', 'multilabel'}.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If the input type is not classification or multilabel.
|
||||||
|
"""
|
||||||
|
def __init__(self, eval_type):
|
||||||
|
super(EvaluationBase, self).__init__()
|
||||||
|
if eval_type not in _eval_types:
|
||||||
|
raise TypeError('Type must be in {}, but got {}'.format(_eval_types, eval_type))
|
||||||
|
self._type = eval_type
|
||||||
|
|
||||||
|
def _check_shape(self, y_pred, y):
|
||||||
|
"""
|
||||||
|
Checks the shapes of y_pred and y.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y_pred (Tensor): Predict array.
|
||||||
|
y (Tensor): Target array.
|
||||||
|
"""
|
||||||
|
if self._type == 'classification':
|
||||||
|
if y_pred.ndim != y.ndim + 1:
|
||||||
|
raise ValueError('Classification case, dims of y_pred equal dims of y add 1, '
|
||||||
|
'but got y_pred: {} dims and y: {} dims'.format(y_pred.ndim, y.ndim))
|
||||||
|
if y.shape != (y_pred.shape[0],) + y_pred.shape[2:]:
|
||||||
|
raise ValueError('Classification case, y_pred shape and y shape can not match. '
|
||||||
|
'got y_pred shape is {} and y shape is {}'.format(y_pred.shape, y.shape))
|
||||||
|
else:
|
||||||
|
if y_pred.ndim != y.ndim:
|
||||||
|
raise ValueError('{} case, dims of y_pred need equal with dims of y, but got y_pred: {} '
|
||||||
|
'dims and y: {} dims.'.format(self._type, y_pred.ndim, y.ndim))
|
||||||
|
if y_pred.shape != y.shape:
|
||||||
|
raise ValueError('{} case, y_pred shape need equal with y shape, but got y_pred: {} and y: {}'.
|
||||||
|
format(self._type, y_pred.shape, y.shape))
|
||||||
|
|
||||||
|
def _check_value(self, y_pred, y):
|
||||||
|
"""
|
||||||
|
Checks the values of y_pred and y.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y_pred (Tensor): Predict array.
|
||||||
|
y (Tensor): Target array.
|
||||||
|
"""
|
||||||
|
if self._type != 'classification' and not (np.equal(y_pred ** 2, y_pred).all() and np.equal(y ** 2, y).all()):
|
||||||
|
raise ValueError('For multilabel case, input value must be 1 or 0.')
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
"""
|
||||||
|
A interface describes the behavior of clearing the internal evaluation result.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
All subclasses must override this interface.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def update(self, *inputs):
|
||||||
|
"""
|
||||||
|
A interface describes the behavior of updating the internal evaluation result.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
All subclasses must override this interface.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: The first item is predicted array and the second item is target array.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def eval(self):
|
||||||
|
"""
|
||||||
|
A interface describes the behavior of computing the evaluation result.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
All subclasses must override this interface.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
|
@ -0,0 +1,196 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""OcclusionSensitivity."""
|
||||||
|
from collections.abc import Sequence
|
||||||
|
import numpy as np
|
||||||
|
from mindspore import nn
|
||||||
|
from mindspore.common.tensor import Tensor
|
||||||
|
from mindspore._checkparam import Validator as validator
|
||||||
|
from .metric import Metric
|
||||||
|
|
||||||
|
try:
|
||||||
|
from tqdm import trange
|
||||||
|
except (ImportError, AttributeError):
|
||||||
|
trange = range
|
||||||
|
|
||||||
|
|
||||||
|
class OcclusionSensitivity(Metric):
|
||||||
|
"""
|
||||||
|
This function is used to calculate the occlusion sensitivity of the model for a given image.
|
||||||
|
Occlusion sensitivity refers to how the probability of a given prediction changes with the change of the occluded
|
||||||
|
part of the image.
|
||||||
|
|
||||||
|
For a given result, the output probability is the probability of a region.
|
||||||
|
|
||||||
|
The higher the value in the output image, the greater the decline of certainty, indicating that
|
||||||
|
the occluded area is more important in the decision-making process.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pad_val (float): What values need to be entered in the image when a part of the image is occluded. Default: 0.0.
|
||||||
|
margin (Union[int, Sequence]): Create a cuboid / cube around the voxel you want to occlude. Default: 2.
|
||||||
|
n_batch (int): number of images in a batch before inference. Default: 128.
|
||||||
|
b_box (Sequence): Bounding box on which to perform the analysis. The output image will also match in size.
|
||||||
|
There should be a minimum and maximum for all dimensions except batch:
|
||||||
|
``[min1, max1, min2, max2,...]``. If no bounding box is supplied, this will be the same size
|
||||||
|
as the input image. If a bounding box is used, the output image will be cropped to this size.
|
||||||
|
Default: None.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> class DenseNet(nn.Cell):
|
||||||
|
>>> def init(self):
|
||||||
|
>>> super(DenseNet, self).init()
|
||||||
|
>>> w = np.array([[0.1, 0.8, 0.1, 0.1],[1, 1, 1, 1]]).astype(np.float32)
|
||||||
|
>>> b = np.array([0.3, 0.6]).astype(np.float32)
|
||||||
|
>>> self.dense = nn.Dense(4, 2, weight_init=Tensor(w), bias_init=Tensor(b))
|
||||||
|
>>>
|
||||||
|
>>> def construct(self, x):
|
||||||
|
>>> return self.dense(x)
|
||||||
|
>>>
|
||||||
|
>>> model = DenseNet()
|
||||||
|
>>> test_data = np.array([[0.1, 0.2, 0.3, 0.4]]).astype(np.float32)
|
||||||
|
>>> label = np.array(1).astype(np.int32)
|
||||||
|
>>> metric = OcclusionSensitivity()
|
||||||
|
>>> metric.clear()
|
||||||
|
>>> metric.update(model, test_data, label)
|
||||||
|
>>> score = metric.eval()
|
||||||
|
[0.29999995 0.6 1 0.9]
|
||||||
|
"""
|
||||||
|
def __init__(self, pad_val=0.0, margin=2, n_batch=128, b_box=None):
|
||||||
|
super().__init__()
|
||||||
|
self.pad_val = validator.check_value_type("pad_val", pad_val, [float])
|
||||||
|
self.margin = validator.check_value_type("margin", margin, [int, Sequence])
|
||||||
|
self.n_batch = validator.check_value_type("n_batch", n_batch, [int])
|
||||||
|
self.b_box = b_box if b_box is None else validator.check_value_type("b_box", b_box, [list])
|
||||||
|
self.clear()
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
"""Clears the internal evaluation result."""
|
||||||
|
self._baseline = 0
|
||||||
|
self._sensitivity_im = 0
|
||||||
|
self._is_update = False
|
||||||
|
|
||||||
|
def _check_input_bounding_box(self, b_box, im_shape):
|
||||||
|
"""Check that the bounding box (if supplied) is as expected."""
|
||||||
|
# If no bounding box has been supplied, set min and max to None
|
||||||
|
if b_box is None:
|
||||||
|
b_box_min = b_box_max = None
|
||||||
|
else:
|
||||||
|
if len(b_box) != 2 * len(im_shape):
|
||||||
|
raise ValueError("Bounding box should contain upper and lower for all dimensions (except batch number)")
|
||||||
|
|
||||||
|
b_box_min = np.array(b_box[::2])
|
||||||
|
b_box_max = np.array(b_box[1::2])
|
||||||
|
b_box_min[b_box_min < 0] = 0
|
||||||
|
b_box_max[b_box_max < 0] = im_shape[b_box_max < 0] - 1
|
||||||
|
if np.any(b_box_max >= im_shape):
|
||||||
|
raise ValueError("Max bounding box should be < image size for all values")
|
||||||
|
if np.any(b_box_min > b_box_max):
|
||||||
|
raise ValueError("Min bounding box should be <= max for all values")
|
||||||
|
|
||||||
|
return b_box_min, b_box_max
|
||||||
|
|
||||||
|
def _append_to_sensitivity_im(self, model, batch_images, batch_ids, sensitivity_im):
|
||||||
|
"""For a given number of images, the probability of predicting a given label is obtained. Attach to previous
|
||||||
|
assessment."""
|
||||||
|
batch_images = np.vstack(batch_images)
|
||||||
|
batch_ids = np.expand_dims(batch_ids, 1)
|
||||||
|
model_numpy = model(Tensor(batch_images)).asnumpy()
|
||||||
|
first_indices = np.arange(batch_ids.shape[0])[:, None]
|
||||||
|
scores = model_numpy[first_indices, batch_ids]
|
||||||
|
if sensitivity_im.size == 0:
|
||||||
|
return np.vstack(scores)
|
||||||
|
return np.vstack((sensitivity_im, scores))
|
||||||
|
|
||||||
|
def update(self, *inputs):
|
||||||
|
"""
|
||||||
|
Updates input, including `model`, `y_pred` and `label`.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **model** (nn.Cell) - classification model to use for inference.
|
||||||
|
- **y_pred** (Union[Tensor, list, np.ndarray]) - image to test. Should be tensor consisting of 1 batch,
|
||||||
|
can be 2- or 3D.
|
||||||
|
- **label** (Union[int, Tensor]) - classification label to check for changes (normally the true label,
|
||||||
|
but doesn't have to be
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the number of input is not 3.
|
||||||
|
"""
|
||||||
|
if len(inputs) != 3:
|
||||||
|
raise ValueError('occlusion_sensitivity need 3 inputs (model, y_pred, y), but got {}'.format(len(inputs)))
|
||||||
|
|
||||||
|
model = inputs[0]
|
||||||
|
y_pred = self._convert_data(inputs[1])
|
||||||
|
label = self._convert_data(inputs[2])
|
||||||
|
model = validator.check_value_type("model", model, [nn.Cell])
|
||||||
|
|
||||||
|
if y_pred.shape[0] > 1:
|
||||||
|
raise RuntimeError("Expected batch size of 1.")
|
||||||
|
|
||||||
|
if isinstance(label, int):
|
||||||
|
label = np.array([[label]], dtype=int)
|
||||||
|
# If the label is a tensor, make sure there's only 1 element
|
||||||
|
elif np.prod(label.shape) != y_pred.shape[0]:
|
||||||
|
raise RuntimeError("Expected as many labels as batches.")
|
||||||
|
|
||||||
|
y_pred_shape = np.array(y_pred.shape[1:])
|
||||||
|
b_box_min, b_box_max = self._check_input_bounding_box(self.b_box, y_pred_shape)
|
||||||
|
|
||||||
|
temp = model(Tensor(y_pred)).asnumpy()
|
||||||
|
self._baseline = temp[0, label].item()
|
||||||
|
|
||||||
|
batch_images = []
|
||||||
|
batch_ids = []
|
||||||
|
|
||||||
|
sensitivity_im = np.empty(0, dtype=float)
|
||||||
|
|
||||||
|
output_im_shape = y_pred_shape if self.b_box is None else b_box_max - b_box_min + 1
|
||||||
|
num_required_predictions = np.prod(output_im_shape)
|
||||||
|
|
||||||
|
for i in trange(num_required_predictions):
|
||||||
|
idx = np.unravel_index(i, output_im_shape)
|
||||||
|
if b_box_min is not None:
|
||||||
|
idx += b_box_min
|
||||||
|
|
||||||
|
min_idx = [max(0, i - self.margin) for i in idx]
|
||||||
|
max_idx = [min(j, i + self.margin) for i, j in zip(idx, y_pred_shape)]
|
||||||
|
|
||||||
|
occlu_im = y_pred.copy()
|
||||||
|
occlu_im[(...,) + tuple(slice(i, j) for i, j in zip(min_idx, max_idx))] = self.pad_val
|
||||||
|
|
||||||
|
batch_images.append(occlu_im)
|
||||||
|
batch_ids.append(label)
|
||||||
|
|
||||||
|
if len(batch_images) == self.n_batch or i == num_required_predictions - 1:
|
||||||
|
sensitivity_im = self._append_to_sensitivity_im(model, batch_images, batch_ids, sensitivity_im)
|
||||||
|
batch_images = []
|
||||||
|
batch_ids = []
|
||||||
|
|
||||||
|
self._sensitivity_im = sensitivity_im.reshape(output_im_shape)
|
||||||
|
self._is_update = True
|
||||||
|
|
||||||
|
def eval(self):
|
||||||
|
"""
|
||||||
|
Computes the occlusion_sensitivity.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A numpy ndarray.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not self._is_update:
|
||||||
|
raise RuntimeError('Call the update method before calling eval.')
|
||||||
|
|
||||||
|
sensitivity = self._baseline - np.squeeze(self._sensitivity_im)
|
||||||
|
|
||||||
|
return sensitivity
|
|
@ -18,7 +18,7 @@ import sys
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from mindspore._checkparam import Validator as validator
|
from mindspore._checkparam import Validator as validator
|
||||||
from ._evaluation import EvaluationBase
|
from .metric import EvaluationBase
|
||||||
|
|
||||||
|
|
||||||
class Precision(EvaluationBase):
|
class Precision(EvaluationBase):
|
||||||
|
|
|
@ -18,7 +18,7 @@ import sys
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from mindspore._checkparam import Validator as validator
|
from mindspore._checkparam import Validator as validator
|
||||||
from ._evaluation import EvaluationBase
|
from .metric import EvaluationBase
|
||||||
|
|
||||||
|
|
||||||
class Recall(EvaluationBase):
|
class Recall(EvaluationBase):
|
||||||
|
|
|
@ -0,0 +1,73 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""test_bleu_score"""
|
||||||
|
import math
|
||||||
|
import pytest
|
||||||
|
from mindspore.nn.metrics import BleuScore
|
||||||
|
|
||||||
|
|
||||||
|
def test_bleu_score():
|
||||||
|
"""test_bleu_score"""
|
||||||
|
candidate_corpus = [['i', 'have', 'a', 'pen', 'on', 'my', 'desk']]
|
||||||
|
reference_corpus = [[['i', 'have', 'a', 'pen', 'in', 'my', 'desk'],
|
||||||
|
['there', 'is', 'a', 'pen', 'on', 'the', 'desk']]]
|
||||||
|
metric = BleuScore(n_gram=4, smooth=False)
|
||||||
|
metric.clear()
|
||||||
|
metric.update(candidate_corpus, reference_corpus)
|
||||||
|
bleu_score = metric.eval()
|
||||||
|
|
||||||
|
assert math.isclose(bleu_score, 0.5946035575013605, abs_tol=0.0001)
|
||||||
|
|
||||||
|
|
||||||
|
def test_bleu_score_update1():
|
||||||
|
"""test_bleu_score_update1"""
|
||||||
|
candidate_corpus = ['the cat is on the mat'.split()]
|
||||||
|
metric = BleuScore()
|
||||||
|
metric.clear()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
metric.update(candidate_corpus)
|
||||||
|
|
||||||
|
|
||||||
|
def test_bleu_score_update2():
|
||||||
|
"""test_bleu_score_update2"""
|
||||||
|
candidate_corpus = [['the cat is on the mat'.split()], ['a cat is on the mat'.split()]]
|
||||||
|
reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
|
||||||
|
metric = BleuScore()
|
||||||
|
metric.clear()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
metric.update(candidate_corpus, reference_corpus)
|
||||||
|
|
||||||
|
|
||||||
|
def test_bleu_score_init1():
|
||||||
|
"""test_bleu_score_init1"""
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
BleuScore(n_gram="3")
|
||||||
|
|
||||||
|
|
||||||
|
def test_bleu_score_init2():
|
||||||
|
"""test_bleu_score_init2"""
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
BleuScore(smooth=5)
|
||||||
|
|
||||||
|
|
||||||
|
def test_bleu_score_runtime():
|
||||||
|
"""test_bleu_score_runtime"""
|
||||||
|
metric = BleuScore()
|
||||||
|
metric.clear()
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
metric.eval()
|
|
@ -0,0 +1,95 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""test cosine_similarity"""
|
||||||
|
import pytest
|
||||||
|
import numpy as np
|
||||||
|
from sklearn.metrics import pairwise
|
||||||
|
from mindspore.nn.metrics import CosineSimilarity
|
||||||
|
|
||||||
|
|
||||||
|
def test_cosine_similarity():
|
||||||
|
"""test_cosine_similarity"""
|
||||||
|
test_data = np.array([[5, 8, 3, 2], [5, 8, 3, 2], [4, 2, 3, 4]])
|
||||||
|
metric = CosineSimilarity()
|
||||||
|
metric.clear()
|
||||||
|
metric.update(test_data)
|
||||||
|
square_matrix = metric.eval()
|
||||||
|
|
||||||
|
assert np.allclose(square_matrix, np.array([[0, 1, 0.78229315], [1, 0, 0.78229315], [0.78229315, 0.78229315, 0]]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_cosine_similarity_compare():
|
||||||
|
"""test_cosine_similarity_compare"""
|
||||||
|
test_data = np.array([[5, 8, 3, 2], [5, 8, 3, 2], [4, 2, 3, 4]])
|
||||||
|
metric = CosineSimilarity(similarity='cosine', reduction='none', zero_diagonal=False)
|
||||||
|
metric.clear()
|
||||||
|
metric.update(test_data)
|
||||||
|
ms_square_matrix = metric.eval()
|
||||||
|
|
||||||
|
def sklearn_cosine_similarity(test_data, similarity, reduction):
|
||||||
|
"""sklearn_cosine_similarity"""
|
||||||
|
metric_func = {'cosine': pairwise.cosine_similarity,
|
||||||
|
'dot': pairwise.linear_kernel}[similarity]
|
||||||
|
|
||||||
|
square_matrix = metric_func(test_data, test_data)
|
||||||
|
if reduction == 'mean':
|
||||||
|
return square_matrix.mean(axis=-1)
|
||||||
|
if reduction == 'sum':
|
||||||
|
return square_matrix.sum(axis=-1)
|
||||||
|
return square_matrix
|
||||||
|
|
||||||
|
sk_square_matrix = sklearn_cosine_similarity(test_data, similarity='cosine', reduction='none')
|
||||||
|
|
||||||
|
assert np.allclose(sk_square_matrix, ms_square_matrix)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cosine_similarity_init1():
|
||||||
|
"""test_cosine_similarity_init1"""
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
CosineSimilarity(similarity="4")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cosine_similarity_init2():
|
||||||
|
"""test_cosine_similarity_init2"""
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
CosineSimilarity(similarity=4)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cosine_similarity_init3():
|
||||||
|
"""test_cosine_similarity_init3"""
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
CosineSimilarity(reduction=2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cosine_similarity_init4():
|
||||||
|
"""test_cosine_similarity_init4"""
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
CosineSimilarity(reduction="1")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_cosine_similarity_init5():
|
||||||
|
"""test_cosine_similarity_init5"""
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
CosineSimilarity(zero_diagonal=3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cosine_similarity_runtime():
|
||||||
|
"""test_cosine_similarity_runtime"""
|
||||||
|
metric = CosineSimilarity()
|
||||||
|
metric.clear()
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
metric.eval()
|
|
@ -0,0 +1,77 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""test_occlusion_sensitivity"""
|
||||||
|
import pytest
|
||||||
|
import numpy as np
|
||||||
|
from mindspore import nn
|
||||||
|
from mindspore.common.tensor import Tensor
|
||||||
|
from mindspore.nn.metrics import OcclusionSensitivity
|
||||||
|
|
||||||
|
|
||||||
|
class DenseNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(DenseNet, self).__init__()
|
||||||
|
w = np.array([[0.1, 0.8, 0.1, 0.1], [1, 1, 1, 1]]).astype(np.float32)
|
||||||
|
b = np.array([0.3, 0.6]).astype(np.float32)
|
||||||
|
self.dense = nn.Dense(4, 2, weight_init=Tensor(w), bias_init=Tensor(b))
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
return self.dense(x)
|
||||||
|
|
||||||
|
|
||||||
|
model = DenseNet()
|
||||||
|
|
||||||
|
|
||||||
|
def test_occlusion_sensitivity():
|
||||||
|
"""test_occlusion_sensitivity"""
|
||||||
|
test_data = np.array([[0.1, 0.2, 0.3, 0.4]]).astype(np.float32)
|
||||||
|
label = np.array(1).astype(np.int32)
|
||||||
|
metric = OcclusionSensitivity()
|
||||||
|
metric.clear()
|
||||||
|
metric.update(model, test_data, label)
|
||||||
|
score = metric.eval()
|
||||||
|
|
||||||
|
assert np.allclose(score, np.array([0.2, 0.2, 0.2, 0.2]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_occlusion_sensitivity_update1():
|
||||||
|
"""test_occlusion_sensitivity_update1"""
|
||||||
|
test_data = np.array([[5, 8], [3, 2], [4, 2]])
|
||||||
|
metric = OcclusionSensitivity()
|
||||||
|
metric.clear()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
metric.update(test_data)
|
||||||
|
|
||||||
|
|
||||||
|
def test_occlusion_sensitivity_init1():
|
||||||
|
"""test_occlusion_sensitivity_init1"""
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
OcclusionSensitivity(pad_val=False, margin=2, n_batch=128, b_box=None)
|
||||||
|
|
||||||
|
|
||||||
|
def test_occlusion_sensitivity_init2():
|
||||||
|
"""test_occlusion_sensitivity_init2"""
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
OcclusionSensitivity(pad_val=0.0, margin=True, n_batch=128, b_box=None)
|
||||||
|
|
||||||
|
|
||||||
|
def test_occlusion_sensitivity_runtime():
|
||||||
|
"""test_occlusion_sensitivity_runtime"""
|
||||||
|
metric = OcclusionSensitivity()
|
||||||
|
metric.clear()
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
metric.eval()
|
Loading…
Reference in New Issue