diff --git a/mindspore/nn/metrics/metric.py b/mindspore/nn/metrics/metric.py index 39a188b8b24..8bb7242cf94 100644 --- a/mindspore/nn/metrics/metric.py +++ b/mindspore/nn/metrics/metric.py @@ -23,7 +23,30 @@ _eval_types = {'classification', 'multilabel'} def rearrange_inputs(func): """ - This decorator is used to rearrange the inputs according to its indexes. + This decorator is used to rearrange the inputs according to its `_indexes` attributes + which is specified by the `set_indexes` method. + + Examples: + >>> class RearrangeInputsExample: + ... def __init__(self): + ... self._indexes = None + ... + ... @property + ... def indexes(self): + ... return getattr(self, '_indexes', None) + ... + ... def set_indexes(self, indexes): + ... self._indexes = indexes + ... return self + ... + ... @rearrange_inputs + ... def update(self, *inputs): + ... return inputs + >>> + >>> rearrange_inputs_example = RearrangeInputsExample().set_indexes([1, 0]) + >>> outs = rearrange_inputs_example.update(5, 9) + >>> print(outs) + >>> (9, 5) Args: func (Callable): A candidate function to be wrapped whose input will be rearranged. @@ -43,9 +66,8 @@ class Metric(metaclass=ABCMeta): """ Base class of metric. - Note: - For examples of subclasses, please refer to the definition of class `MAE`, 'Recall' etc. + For examples of subclasses, please refer to the definition of class `MAE`, `Recall` etc. """ def __init__(self): self._indexes = None @@ -117,9 +139,42 @@ class Metric(metaclass=ABCMeta): @property def indexes(self): + """The `_indexes` is a private attributes, and you can retrieve it by `self.indexes`. + """ return getattr(self, '_indexes', None) def set_indexes(self, indexes): + """ + The `_indexes` is a private attributes, and you can modify it by this function. + This allows you to determine the order of logits and labels to be calculated in the + inputs, specially when you call the method `update` within this metrics. + + Note: + It has been applied in subclass of Metric, eg. `Accuracy`, `BleuScore`, `ConfusionMatrix`, + `CosineSimilarity`, `MAE`, and `MSE`. + + Args: + indexes (List(int)): The order of logits and labels to be rearranged. + + Outputs: + :class:`Metric`, its original Class instance. + + Examples: + >>> from mindspore import Tensor + >>> from mindspore.nn import Accuracy + >>> + >>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]])) + >>> y = Tensor(np.array([1, 0, 1])) + >>> y2 = Tensor(np.array([0, 0, 1])) + >>> metric = Accuracy('classification').set_indexes([0, 2]) + >>> metric.clear() + >>> metric.update(x, y, y2) + >>> accuracy = metric.eval() + >>> print(accuracy) + 0.3333333333333333 + """ + if not isinstance(indexes, list) or not all(isinstance(i, int) for i in indexes): + raise ValueError("The indexes should be a list and all its elements should be int") self._indexes = indexes return self diff --git a/tests/ut/python/metrics/test_accuracy.py b/tests/ut/python/metrics/test_accuracy.py index f3cc8816e19..bb4aa1a46dc 100644 --- a/tests/ut/python/metrics/test_accuracy.py +++ b/tests/ut/python/metrics/test_accuracy.py @@ -47,6 +47,12 @@ def test_classification_accuracy_indexes_awareness(): assert math.isclose(accuracy, 1 / 3) +@pytest.mark.parametrize('indexes', [0, [0., 2.], [0., 1], ['1', '0']]) +def test_set_indexes(indexes): + with pytest.raises(ValueError, match="indexes should be a list and all its elements should be int"): + _ = Accuracy('classification').set_indexes(indexes) + + def test_multilabel_accuracy(): x = Tensor(np.array([[0, 1, 0, 1], [1, 0, 1, 1], [0, 0, 0, 1]])) y = Tensor(np.array([[0, 1, 1, 1], [0, 1, 1, 1], [0, 0, 0, 1]]))