From f6de97fc272ff99c14b3708c91cdf4dd4fb6b86f Mon Sep 17 00:00:00 2001 From: Jiaqi Date: Tue, 22 Dec 2020 16:19:00 +0800 Subject: [PATCH] bleu_score --- mindspore/nn/metrics/__init__.py | 9 + mindspore/nn/metrics/_evaluation.py | 103 --------- mindspore/nn/metrics/accuracy.py | 2 +- mindspore/nn/metrics/bleu_score.py | 149 +++++++++++++ mindspore/nn/metrics/cosine_similarity.py | 96 +++++++++ mindspore/nn/metrics/metric.py | 86 ++++++++ mindspore/nn/metrics/occlusion_sensitivity.py | 196 ++++++++++++++++++ mindspore/nn/metrics/precision.py | 2 +- mindspore/nn/metrics/recall.py | 2 +- tests/ut/python/metrics/test_bleu_score.py | 73 +++++++ .../python/metrics/test_cosine_similarity.py | 95 +++++++++ .../metrics/test_occlusion_sensitivity.py | 77 +++++++ 12 files changed, 784 insertions(+), 106 deletions(-) delete mode 100644 mindspore/nn/metrics/_evaluation.py create mode 100644 mindspore/nn/metrics/bleu_score.py create mode 100644 mindspore/nn/metrics/cosine_similarity.py create mode 100644 mindspore/nn/metrics/occlusion_sensitivity.py create mode 100644 tests/ut/python/metrics/test_bleu_score.py create mode 100644 tests/ut/python/metrics/test_cosine_similarity.py create mode 100644 tests/ut/python/metrics/test_occlusion_sensitivity.py diff --git a/mindspore/nn/metrics/__init__.py b/mindspore/nn/metrics/__init__.py index cddfa35f96c..211c201ce87 100755 --- a/mindspore/nn/metrics/__init__.py +++ b/mindspore/nn/metrics/__init__.py @@ -32,6 +32,9 @@ from .topk import TopKCategoricalAccuracy, Top1CategoricalAccuracy, Top5Categori from .loss import Loss from .mean_surface_distance import MeanSurfaceDistance from .root_mean_square_surface_distance import RootMeanSquareDistance +from .bleu_score import BleuScore +from .cosine_similarity import CosineSimilarity +from .occlusion_sensitivity import OcclusionSensitivity __all__ = [ "names", @@ -43,6 +46,9 @@ __all__ = [ "HausdorffDistance", "Recall", "Fbeta", + "BleuScore", + "CosineSimilarity", + "OcclusionSensitivity", "F1", "Dice", "ROC", @@ -64,6 +70,9 @@ __factory__ = { 'dice': Dice, 'roc': ROC, 'auc': auc, + 'bleu_score': BleuScore, + 'cosine_similarity': CosineSimilarity, + 'occlusion_sensitivity': OcclusionSensitivity, 'topk': TopKCategoricalAccuracy, 'hausdorff_distance': HausdorffDistance, 'top_1_accuracy': Top1CategoricalAccuracy, diff --git a/mindspore/nn/metrics/_evaluation.py b/mindspore/nn/metrics/_evaluation.py deleted file mode 100644 index d9c32bb162e..00000000000 --- a/mindspore/nn/metrics/_evaluation.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Evaluation.""" -import numpy as np -from .metric import Metric - -_eval_types = {'classification', 'multilabel'} - - -class EvaluationBase(Metric): - """ - Base class of evaluation. - - Note: - Please refer to the definition of class `Accuracy`. 
- - Args: - eval_type (str): Type of evaluation must be in {'classification', 'multilabel'}. - - Raises: - TypeError: If the input type is not classification or multilabel. - """ - def __init__(self, eval_type): - super(EvaluationBase, self).__init__() - if eval_type not in _eval_types: - raise TypeError('Type must be in {}, but got {}'.format(_eval_types, eval_type)) - self._type = eval_type - - def _check_shape(self, y_pred, y): - """ - Checks the shapes of y_pred and y. - - Args: - y_pred (Tensor): Predict array. - y (Tensor): Target array. - """ - if self._type == 'classification': - if y_pred.ndim != y.ndim + 1: - raise ValueError('Classification case, dims of y_pred equal dims of y add 1, ' - 'but got y_pred: {} dims and y: {} dims'.format(y_pred.ndim, y.ndim)) - if y.shape != (y_pred.shape[0],) + y_pred.shape[2:]: - raise ValueError('Classification case, y_pred shape and y shape can not match. ' - 'got y_pred shape is {} and y shape is {}'.format(y_pred.shape, y.shape)) - else: - if y_pred.ndim != y.ndim: - raise ValueError('{} case, dims of y_pred need equal with dims of y, but got y_pred: {} ' - 'dims and y: {} dims.'.format(self._type, y_pred.ndim, y.ndim)) - if y_pred.shape != y.shape: - raise ValueError('{} case, y_pred shape need equal with y shape, but got y_pred: {} and y: {}'. - format(self._type, y_pred.shape, y.shape)) - - def _check_value(self, y_pred, y): - """ - Checks the values of y_pred and y. - - Args: - y_pred (Tensor): Predict array. - y (Tensor): Target array. - """ - if self._type != 'classification' and not (np.equal(y_pred ** 2, y_pred).all() and np.equal(y ** 2, y).all()): - raise ValueError('For multilabel case, input value must be 1 or 0.') - - def clear(self): - """ - A interface describes the behavior of clearing the internal evaluation result. - - Note: - All subclasses must override this interface. - """ - raise NotImplementedError - - def update(self, *inputs): - """ - A interface describes the behavior of updating the internal evaluation result. - - Note: - All subclasses must override this interface. - - Args: - inputs: The first item is predicted array and the second item is target array. - """ - raise NotImplementedError - - def eval(self): - """ - A interface describes the behavior of computing the evaluation result. - - Note: - All subclasses must override this interface. - """ - raise NotImplementedError diff --git a/mindspore/nn/metrics/accuracy.py b/mindspore/nn/metrics/accuracy.py index 46652759ec1..7fc6af6d3ef 100644 --- a/mindspore/nn/metrics/accuracy.py +++ b/mindspore/nn/metrics/accuracy.py @@ -14,7 +14,7 @@ # ============================================================================ """Accuracy.""" import numpy as np -from ._evaluation import EvaluationBase +from .metric import EvaluationBase class Accuracy(EvaluationBase): diff --git a/mindspore/nn/metrics/bleu_score.py b/mindspore/nn/metrics/bleu_score.py new file mode 100644 index 00000000000..58b8b381b8b --- /dev/null +++ b/mindspore/nn/metrics/bleu_score.py @@ -0,0 +1,149 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""BleuScore.""" +from collections import Counter +import numpy as np +from mindspore._checkparam import Validator as validator +from .metric import Metric + + +class BleuScore(Metric): + """ + Calculate BLEU score of machine translated text with one or more references. + + Args: + n_gram (int): The n_gram value ranged from 1 to 4. Default: 4 + smooth (bool): Whether or not to apply smoothing. Default: False + + Example: + >>> candidate_corpus = [['i', 'have', 'a', 'pen', 'on', 'my', 'desk']] + >>> reference_corpus = [[['i', 'have', 'a', 'pen', 'in', 'my', 'desk'], + >>> ['there', 'is', 'a', 'pen', 'on', 'the', 'desk']]] + >>> metric = BleuScore() + >>> metric.clear() + >>> metric.update(candidate_corpus, reference_corpus) + >>> bleu_score = metric.eval() + 0.5946035575013605 + """ + def __init__(self, n_gram=4, smooth=False): + super().__init__() + self.n_gram = validator.check_value_type("n_gram", n_gram, [int]) + if self.n_gram > 4 or self.n_gram < 1: + raise ValueError('The n_gram value ranged from 1 to 4, but got {}'.format(n_gram)) + + self.smooth = validator.check_value_type("smooth", smooth, [bool]) + self.clear() + + def clear(self): + """Clear the internal evaluation result.""" + self._numerator = np.zeros(self.n_gram) + self._denominator = np.zeros(self.n_gram) + self._precision_scores = np.zeros(self.n_gram) + self._c = 0.0 + self._r = 0.0 + self._trans_len = 0 + self._ref_len = 0 + self._is_update = False + + def _count_ngram(self, ngram_input_list, n_gram): + """ + Counting how many times each word appears in a given text with ngram. + + Args: + ngram_input_list (list): A list of translated text or reference texts. + n_gram (int): gram value ranged 1 to 4. + + Return: + ngram_counter: a collections.Counter object of ngram. + """ + + ngram_counter = Counter() + + for i in range(1, n_gram + 1): + for j in range(len(ngram_input_list) - i + 1): + ngram_key = tuple(ngram_input_list[j:(i + j)]) + ngram_counter[ngram_key] += 1 + + return ngram_counter + + def update(self, *inputs): + """ + Updates the internal evaluation result with `candidate_corpus` and `reference_corpus`. + + Args: + inputs: Input `candidate_corpus` and ``reference_corpus`. `candidate_corpus` and `reference_corpus` are a + list. The `candidate_corpus` is an iterable of machine translated corpus. The `reference_corpus` is + an iterable of iterables of reference corpus. + + Raises: + ValueError: If the number of input is not 2. 
+ """ + if len(inputs) != 2: + raise ValueError('The bleu_score need 2 inputs (candidate_corpus, reference_corpus), ' + 'but got {}'.format(len(inputs))) + candidate_corpus = inputs[0] + reference_corpus = inputs[1] + if len(candidate_corpus) != len(reference_corpus): + raise ValueError('translate_corpus and reference_corpus should be equal in length, ' + 'but got {} {}'.format(len(candidate_corpus), len(reference_corpus))) + + for (candidate, references) in zip(candidate_corpus, reference_corpus): + self._c += len(candidate) + ref_len_list = [len(ref) for ref in references] + ref_len_diff = [abs(len(candidate) - x) for x in ref_len_list] + self._r += ref_len_list[ref_len_diff.index(min(ref_len_diff))] + translation_counter = self._count_ngram(candidate, self.n_gram) + reference_counter = Counter() + + for ref in references: + reference_counter |= self._count_ngram(ref, self.n_gram) + + ngram_counter_clip = translation_counter & reference_counter + + for counter_clip in ngram_counter_clip: + self._numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip] + + for counter in translation_counter: + self._denominator[len(counter) - 1] += translation_counter[counter] + + self._trans_len = np.array(self._c) + self._ref_len = np.array(self._r) + self._is_update = True + + def eval(self): + """ + Computes the bleu score. + + Returns: + A numpy with bleu score. + + """ + if self._is_update is False: + raise RuntimeError('Call the update method before calling eval.') + if min(self._numerator) == 0.0: + return np.array(0.0) + + if self.smooth: + precision_scores = np.add(self._numerator, np.ones(self.n_gram)) / np.add(self._denominator, + np.ones(self.n_gram)) + else: + precision_scores = self._numerator / self._denominator + + log_precision_scores = np.array([1.0 / self.n_gram] * self.n_gram) * np.log(precision_scores) + geometric_mean = np.exp(np.sum(log_precision_scores)) + brevity_penalty = np.array(1.0) if self._c > self._r else np.exp(1 - (self._ref_len / self._trans_len)) + bleu = brevity_penalty * geometric_mean + + return bleu diff --git a/mindspore/nn/metrics/cosine_similarity.py b/mindspore/nn/metrics/cosine_similarity.py new file mode 100644 index 00000000000..71af0ecfff0 --- /dev/null +++ b/mindspore/nn/metrics/cosine_similarity.py @@ -0,0 +1,96 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""CosineSimilarity.""" +import numpy as np +from mindspore._checkparam import Validator as validator +from .metric import Metric + + +class CosineSimilarity(Metric): + """ + Computes representation similarity + + Args: + similarity (str): 'dot' or 'cosine'. Default: 'cosine' + reduction (str): 'none', 'sum', 'mean' (all along dim -1). Default: 'none' + zero_diagonal (bool): if True, the diagonals are set to zero. Default: True + + Return: + A square matrix (input1, input1) with the similarity scores between all elements. 
+ If sum or mean are used, then returns (b, 1) with the reduced value for each row + + Example: + >>> test_data = np.random.randn(4, 8) + >>> metric = CosineSimilarity() + >>> metric.clear() + >>> metric.update(test_data) + >>> square_matrix = metric.eval() + [[0. -0.14682831 0.19102288 -0.36204537] + ... + ] + """ + def __init__(self, similarity='cosine', reduction='none', zero_diagonal=True): + super().__init__() + similarity_list = ['dot', 'cosine'] + reduction_list = ['none', 'sum', 'mean'] + similarity = validator.check_value_type("similarity", similarity, [str]) + self.similarity = validator.check_string(similarity, similarity_list, "similarity") + reduction = validator.check_value_type("reduction", reduction, [str]) + self.reduction = validator.check_string(reduction, reduction_list, "reduction") + self.zero_diagonal = validator.check_value_type("zero_diagonal", zero_diagonal, [bool]) + self.clear() + + def clear(self): + """Clears the internal evaluation result.""" + self.sqr_mtx_res = 0 + self._is_update = False + + def update(self, *inputs): + """ + Updates the internal evaluation result with 'input1'. + + Args: + inputs: input_data `input1`. The input_data is a `Tensor`or an array. + """ + input_data = self._convert_data(inputs[0]) + + if self.similarity == 'cosine': + data = np.linalg.norm(input_data, ord=2, axis=1) + input_data = input_data / np.expand_dims(data, 1) + + self.sqr_mtx_res = np.dot(input_data, input_data.transpose(1, 0)) + self._is_update = True + + def eval(self): + """ + Computes the Cosine_Similarity square matrix. + + Returns: + A square matrix. + + """ + if not self._is_update: + raise RuntimeError('Call the update method before calling eval.') + + if self.zero_diagonal: + np.fill_diagonal(self.sqr_mtx_res, 0) + + if self.reduction == 'mean': + self.sqr_mtx_res = np.mean(self.sqr_mtx_res, axis=-1) + + if self.reduction == 'sum': + self.sqr_mtx_res = np.sum(self.sqr_mtx_res, axis=-1) + + return self.sqr_mtx_res diff --git a/mindspore/nn/metrics/metric.py b/mindspore/nn/metrics/metric.py index 13e1775e53d..43269e8757e 100644 --- a/mindspore/nn/metrics/metric.py +++ b/mindspore/nn/metrics/metric.py @@ -17,6 +17,8 @@ from abc import ABCMeta, abstractmethod import numpy as np from mindspore.common.tensor import Tensor +_eval_types = {'classification', 'multilabel'} + class Metric(metaclass=ABCMeta): """ @@ -140,3 +142,87 @@ class Metric(metaclass=ABCMeta): inputs: A variable-length input argument list. """ raise NotImplementedError('Must define update function to use this base class') + + +class EvaluationBase(Metric): + """ + Base class of evaluation. + + Note: + Please refer to the definition of class `Accuracy`. + + Args: + eval_type (str): Type of evaluation must be in {'classification', 'multilabel'}. + + Raises: + TypeError: If the input type is not classification or multilabel. + """ + def __init__(self, eval_type): + super(EvaluationBase, self).__init__() + if eval_type not in _eval_types: + raise TypeError('Type must be in {}, but got {}'.format(_eval_types, eval_type)) + self._type = eval_type + + def _check_shape(self, y_pred, y): + """ + Checks the shapes of y_pred and y. + + Args: + y_pred (Tensor): Predict array. + y (Tensor): Target array. 
+ """ + if self._type == 'classification': + if y_pred.ndim != y.ndim + 1: + raise ValueError('Classification case, dims of y_pred equal dims of y add 1, ' + 'but got y_pred: {} dims and y: {} dims'.format(y_pred.ndim, y.ndim)) + if y.shape != (y_pred.shape[0],) + y_pred.shape[2:]: + raise ValueError('Classification case, y_pred shape and y shape can not match. ' + 'got y_pred shape is {} and y shape is {}'.format(y_pred.shape, y.shape)) + else: + if y_pred.ndim != y.ndim: + raise ValueError('{} case, dims of y_pred need equal with dims of y, but got y_pred: {} ' + 'dims and y: {} dims.'.format(self._type, y_pred.ndim, y.ndim)) + if y_pred.shape != y.shape: + raise ValueError('{} case, y_pred shape need equal with y shape, but got y_pred: {} and y: {}'. + format(self._type, y_pred.shape, y.shape)) + + def _check_value(self, y_pred, y): + """ + Checks the values of y_pred and y. + + Args: + y_pred (Tensor): Predict array. + y (Tensor): Target array. + """ + if self._type != 'classification' and not (np.equal(y_pred ** 2, y_pred).all() and np.equal(y ** 2, y).all()): + raise ValueError('For multilabel case, input value must be 1 or 0.') + + def clear(self): + """ + A interface describes the behavior of clearing the internal evaluation result. + + Note: + All subclasses must override this interface. + """ + raise NotImplementedError + + def update(self, *inputs): + """ + A interface describes the behavior of updating the internal evaluation result. + + Note: + All subclasses must override this interface. + + Args: + inputs: The first item is predicted array and the second item is target array. + """ + raise NotImplementedError + + def eval(self): + """ + A interface describes the behavior of computing the evaluation result. + + Note: + All subclasses must override this interface. + """ + raise NotImplementedError diff --git a/mindspore/nn/metrics/occlusion_sensitivity.py b/mindspore/nn/metrics/occlusion_sensitivity.py new file mode 100644 index 00000000000..e69e749bd14 --- /dev/null +++ b/mindspore/nn/metrics/occlusion_sensitivity.py @@ -0,0 +1,196 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""OcclusionSensitivity.""" +from collections.abc import Sequence +import numpy as np +from mindspore import nn +from mindspore.common.tensor import Tensor +from mindspore._checkparam import Validator as validator +from .metric import Metric + +try: + from tqdm import trange +except (ImportError, AttributeError): + trange = range + + +class OcclusionSensitivity(Metric): + """ + This function is used to calculate the occlusion sensitivity of the model for a given image. + Occlusion sensitivity refers to how the probability of a given prediction changes with the change of the occluded + part of the image. + + For a given result, the output probability is the probability of a region. 
+
+    The higher the value in the output image, the greater the decline of certainty, indicating that
+    the occluded area is more important in the decision-making process.
+
+    Args:
+        pad_val (float): The value used to fill the occluded part of the image. Default: 0.0.
+        margin (Union[int, Sequence]): Create a cuboid / cube around the voxel you want to occlude. Default: 2.
+        n_batch (int): Number of images in a batch before inference. Default: 128.
+        b_box (Sequence): Bounding box on which to perform the analysis. The output image will also match in size.
+                          There should be a minimum and maximum for all dimensions except batch:
+                          ``[min1, max1, min2, max2,...]``. If no bounding box is supplied, this will be the same size
+                          as the input image. If a bounding box is used, the output image will be cropped to this size.
+                          Default: None.
+
+    Example:
+        >>> class DenseNet(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(DenseNet, self).__init__()
+        >>>         w = np.array([[0.1, 0.8, 0.1, 0.1],[1, 1, 1, 1]]).astype(np.float32)
+        >>>         b = np.array([0.3, 0.6]).astype(np.float32)
+        >>>         self.dense = nn.Dense(4, 2, weight_init=Tensor(w), bias_init=Tensor(b))
+        >>>
+        >>>     def construct(self, x):
+        >>>         return self.dense(x)
+        >>>
+        >>> model = DenseNet()
+        >>> test_data = np.array([[0.1, 0.2, 0.3, 0.4]]).astype(np.float32)
+        >>> label = np.array(1).astype(np.int32)
+        >>> metric = OcclusionSensitivity()
+        >>> metric.clear()
+        >>> metric.update(model, test_data, label)
+        >>> score = metric.eval()
+        [0.29999995 0.6 1 0.9]
+    """
+    def __init__(self, pad_val=0.0, margin=2, n_batch=128, b_box=None):
+        super().__init__()
+        self.pad_val = validator.check_value_type("pad_val", pad_val, [float])
+        self.margin = validator.check_value_type("margin", margin, [int, Sequence])
+        self.n_batch = validator.check_value_type("n_batch", n_batch, [int])
+        self.b_box = b_box if b_box is None else validator.check_value_type("b_box", b_box, [list])
+        self.clear()
+
+    def clear(self):
+        """Clears the internal evaluation result."""
+        self._baseline = 0
+        self._sensitivity_im = 0
+        self._is_update = False
+
+    def _check_input_bounding_box(self, b_box, im_shape):
+        """Check that the bounding box (if supplied) is as expected."""
+        # If no bounding box has been supplied, set min and max to None
+        if b_box is None:
+            b_box_min = b_box_max = None
+        else:
+            if len(b_box) != 2 * len(im_shape):
+                raise ValueError("Bounding box should contain upper and lower for all dimensions (except batch number)")
+
+            b_box_min = np.array(b_box[::2])
+            b_box_max = np.array(b_box[1::2])
+            b_box_min[b_box_min < 0] = 0
+            b_box_max[b_box_max < 0] = im_shape[b_box_max < 0] - 1
+            if np.any(b_box_max >= im_shape):
+                raise ValueError("Max bounding box should be < image size for all values")
+            if np.any(b_box_min > b_box_max):
+                raise ValueError("Min bounding box should be <= max for all values")
+
+        return b_box_min, b_box_max
+
+    def _append_to_sensitivity_im(self, model, batch_images, batch_ids, sensitivity_im):
+        """For a given number of images, the probability of predicting a given label is obtained.
The result is appended to the previous
+        assessments."""
+        batch_images = np.vstack(batch_images)
+        batch_ids = np.expand_dims(batch_ids, 1)
+        model_numpy = model(Tensor(batch_images)).asnumpy()
+        first_indices = np.arange(batch_ids.shape[0])[:, None]
+        scores = model_numpy[first_indices, batch_ids]
+        if sensitivity_im.size == 0:
+            return np.vstack(scores)
+        return np.vstack((sensitivity_im, scores))
+
+    def update(self, *inputs):
+        """
+        Updates the inputs, including `model`, `y_pred` and `label`.
+
+        Inputs:
+            - **model** (nn.Cell) - classification model to use for inference.
+            - **y_pred** (Union[Tensor, list, np.ndarray]) - image to test. Should be a tensor consisting of 1 batch,
+              can be 2- or 3D.
+            - **label** (Union[int, Tensor]) - classification label to check for changes (normally the true label,
+              but doesn't have to be).
+
+        Raises:
+            ValueError: If the number of inputs is not 3.
+        """
+        if len(inputs) != 3:
+            raise ValueError('occlusion_sensitivity need 3 inputs (model, y_pred, y), but got {}'.format(len(inputs)))
+
+        model = inputs[0]
+        y_pred = self._convert_data(inputs[1])
+        label = self._convert_data(inputs[2])
+        model = validator.check_value_type("model", model, [nn.Cell])
+
+        if y_pred.shape[0] > 1:
+            raise RuntimeError("Expected batch size of 1.")
+
+        if isinstance(label, int):
+            label = np.array([[label]], dtype=int)
+        # If the label is a tensor, make sure there's only 1 element
+        elif np.prod(label.shape) != y_pred.shape[0]:
+            raise RuntimeError("Expected as many labels as batches.")
+
+        y_pred_shape = np.array(y_pred.shape[1:])
+        b_box_min, b_box_max = self._check_input_bounding_box(self.b_box, y_pred_shape)
+
+        temp = model(Tensor(y_pred)).asnumpy()
+        self._baseline = temp[0, label].item()
+
+        batch_images = []
+        batch_ids = []
+
+        sensitivity_im = np.empty(0, dtype=float)
+
+        output_im_shape = y_pred_shape if self.b_box is None else b_box_max - b_box_min + 1
+        num_required_predictions = np.prod(output_im_shape)
+
+        for i in trange(num_required_predictions):
+            idx = np.unravel_index(i, output_im_shape)
+            if b_box_min is not None:
+                idx += b_box_min
+
+            min_idx = [max(0, i - self.margin) for i in idx]
+            max_idx = [min(j, i + self.margin) for i, j in zip(idx, y_pred_shape)]
+
+            occlu_im = y_pred.copy()
+            occlu_im[(...,) + tuple(slice(i, j) for i, j in zip(min_idx, max_idx))] = self.pad_val
+
+            batch_images.append(occlu_im)
+            batch_ids.append(label)
+
+            if len(batch_images) == self.n_batch or i == num_required_predictions - 1:
+                sensitivity_im = self._append_to_sensitivity_im(model, batch_images, batch_ids, sensitivity_im)
+                batch_images = []
+                batch_ids = []
+
+        self._sensitivity_im = sensitivity_im.reshape(output_im_shape)
+        self._is_update = True
+
+    def eval(self):
+        """
+        Computes the occlusion_sensitivity.
+
+        Returns:
+            A numpy ndarray.
+ + """ + if not self._is_update: + raise RuntimeError('Call the update method before calling eval.') + + sensitivity = self._baseline - np.squeeze(self._sensitivity_im) + + return sensitivity diff --git a/mindspore/nn/metrics/precision.py b/mindspore/nn/metrics/precision.py index a0a4c727d70..46f7006e4bb 100644 --- a/mindspore/nn/metrics/precision.py +++ b/mindspore/nn/metrics/precision.py @@ -18,7 +18,7 @@ import sys import numpy as np from mindspore._checkparam import Validator as validator -from ._evaluation import EvaluationBase +from .metric import EvaluationBase class Precision(EvaluationBase): diff --git a/mindspore/nn/metrics/recall.py b/mindspore/nn/metrics/recall.py index 2ee6b5db845..c105f401988 100644 --- a/mindspore/nn/metrics/recall.py +++ b/mindspore/nn/metrics/recall.py @@ -18,7 +18,7 @@ import sys import numpy as np from mindspore._checkparam import Validator as validator -from ._evaluation import EvaluationBase +from .metric import EvaluationBase class Recall(EvaluationBase): diff --git a/tests/ut/python/metrics/test_bleu_score.py b/tests/ut/python/metrics/test_bleu_score.py new file mode 100644 index 00000000000..f5a7f02375a --- /dev/null +++ b/tests/ut/python/metrics/test_bleu_score.py @@ -0,0 +1,73 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""test_bleu_score""" +import math +import pytest +from mindspore.nn.metrics import BleuScore + + +def test_bleu_score(): + """test_bleu_score""" + candidate_corpus = [['i', 'have', 'a', 'pen', 'on', 'my', 'desk']] + reference_corpus = [[['i', 'have', 'a', 'pen', 'in', 'my', 'desk'], + ['there', 'is', 'a', 'pen', 'on', 'the', 'desk']]] + metric = BleuScore(n_gram=4, smooth=False) + metric.clear() + metric.update(candidate_corpus, reference_corpus) + bleu_score = metric.eval() + + assert math.isclose(bleu_score, 0.5946035575013605, abs_tol=0.0001) + + +def test_bleu_score_update1(): + """test_bleu_score_update1""" + candidate_corpus = ['the cat is on the mat'.split()] + metric = BleuScore() + metric.clear() + + with pytest.raises(ValueError): + metric.update(candidate_corpus) + + +def test_bleu_score_update2(): + """test_bleu_score_update2""" + candidate_corpus = [['the cat is on the mat'.split()], ['a cat is on the mat'.split()]] + reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]] + metric = BleuScore() + metric.clear() + + with pytest.raises(ValueError): + metric.update(candidate_corpus, reference_corpus) + + +def test_bleu_score_init1(): + """test_bleu_score_init1""" + with pytest.raises(TypeError): + BleuScore(n_gram="3") + + +def test_bleu_score_init2(): + """test_bleu_score_init2""" + with pytest.raises(TypeError): + BleuScore(smooth=5) + + +def test_bleu_score_runtime(): + """test_bleu_score_runtime""" + metric = BleuScore() + metric.clear() + + with pytest.raises(RuntimeError): + metric.eval() diff --git a/tests/ut/python/metrics/test_cosine_similarity.py b/tests/ut/python/metrics/test_cosine_similarity.py new file mode 100644 index 00000000000..794d351df84 --- /dev/null +++ b/tests/ut/python/metrics/test_cosine_similarity.py @@ -0,0 +1,95 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""test cosine_similarity""" +import pytest +import numpy as np +from sklearn.metrics import pairwise +from mindspore.nn.metrics import CosineSimilarity + + +def test_cosine_similarity(): + """test_cosine_similarity""" + test_data = np.array([[5, 8, 3, 2], [5, 8, 3, 2], [4, 2, 3, 4]]) + metric = CosineSimilarity() + metric.clear() + metric.update(test_data) + square_matrix = metric.eval() + + assert np.allclose(square_matrix, np.array([[0, 1, 0.78229315], [1, 0, 0.78229315], [0.78229315, 0.78229315, 0]])) + + +def test_cosine_similarity_compare(): + """test_cosine_similarity_compare""" + test_data = np.array([[5, 8, 3, 2], [5, 8, 3, 2], [4, 2, 3, 4]]) + metric = CosineSimilarity(similarity='cosine', reduction='none', zero_diagonal=False) + metric.clear() + metric.update(test_data) + ms_square_matrix = metric.eval() + + def sklearn_cosine_similarity(test_data, similarity, reduction): + """sklearn_cosine_similarity""" + metric_func = {'cosine': pairwise.cosine_similarity, + 'dot': pairwise.linear_kernel}[similarity] + + square_matrix = metric_func(test_data, test_data) + if reduction == 'mean': + return square_matrix.mean(axis=-1) + if reduction == 'sum': + return square_matrix.sum(axis=-1) + return square_matrix + + sk_square_matrix = sklearn_cosine_similarity(test_data, similarity='cosine', reduction='none') + + assert np.allclose(sk_square_matrix, ms_square_matrix) + + +def test_cosine_similarity_init1(): + """test_cosine_similarity_init1""" + with pytest.raises(ValueError): + CosineSimilarity(similarity="4") + + +def test_cosine_similarity_init2(): + """test_cosine_similarity_init2""" + with pytest.raises(TypeError): + CosineSimilarity(similarity=4) + + +def test_cosine_similarity_init3(): + """test_cosine_similarity_init3""" + with pytest.raises(TypeError): + CosineSimilarity(reduction=2) + + +def test_cosine_similarity_init4(): + """test_cosine_similarity_init4""" + with pytest.raises(ValueError): + CosineSimilarity(reduction="1") + + + +def test_cosine_similarity_init5(): + """test_cosine_similarity_init5""" + with pytest.raises(TypeError): + CosineSimilarity(zero_diagonal=3) + + +def test_cosine_similarity_runtime(): + """test_cosine_similarity_runtime""" + metric = CosineSimilarity() + metric.clear() + + with pytest.raises(RuntimeError): + metric.eval() diff --git a/tests/ut/python/metrics/test_occlusion_sensitivity.py b/tests/ut/python/metrics/test_occlusion_sensitivity.py new file mode 100644 index 00000000000..3c1c86b8153 --- /dev/null +++ b/tests/ut/python/metrics/test_occlusion_sensitivity.py @@ -0,0 +1,77 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""test_occlusion_sensitivity""" +import pytest +import numpy as np +from mindspore import nn +from mindspore.common.tensor import Tensor +from mindspore.nn.metrics import OcclusionSensitivity + + +class DenseNet(nn.Cell): + def __init__(self): + super(DenseNet, self).__init__() + w = np.array([[0.1, 0.8, 0.1, 0.1], [1, 1, 1, 1]]).astype(np.float32) + b = np.array([0.3, 0.6]).astype(np.float32) + self.dense = nn.Dense(4, 2, weight_init=Tensor(w), bias_init=Tensor(b)) + + def construct(self, x): + return self.dense(x) + + +model = DenseNet() + + +def test_occlusion_sensitivity(): + """test_occlusion_sensitivity""" + test_data = np.array([[0.1, 0.2, 0.3, 0.4]]).astype(np.float32) + label = np.array(1).astype(np.int32) + metric = OcclusionSensitivity() + metric.clear() + metric.update(model, test_data, label) + score = metric.eval() + + assert np.allclose(score, np.array([0.2, 0.2, 0.2, 0.2])) + + +def test_occlusion_sensitivity_update1(): + """test_occlusion_sensitivity_update1""" + test_data = np.array([[5, 8], [3, 2], [4, 2]]) + metric = OcclusionSensitivity() + metric.clear() + + with pytest.raises(ValueError): + metric.update(test_data) + + +def test_occlusion_sensitivity_init1(): + """test_occlusion_sensitivity_init1""" + with pytest.raises(TypeError): + OcclusionSensitivity(pad_val=False, margin=2, n_batch=128, b_box=None) + + +def test_occlusion_sensitivity_init2(): + """test_occlusion_sensitivity_init2""" + with pytest.raises(TypeError): + OcclusionSensitivity(pad_val=0.0, margin=True, n_batch=128, b_box=None) + + +def test_occlusion_sensitivity_runtime(): + """test_occlusion_sensitivity_runtime""" + metric = OcclusionSensitivity() + metric.clear() + + with pytest.raises(RuntimeError): + metric.eval()