!17373 add docs, user-example and assertion for rearrange_inputs

Merge pull request !17373 from zhiqwang/rearrange-inputs-docs
This commit is contained in:
i-robot 2021-06-10 16:17:23 +08:00 committed by Gitee
commit c293f06b27
2 changed files with 64 additions and 3 deletions

View File

@ -23,7 +23,30 @@ _eval_types = {'classification', 'multilabel'}
def rearrange_inputs(func):
"""
This decorator is used to rearrange the inputs according to its indexes.
This decorator is used to rearrange the inputs according to its `_indexes` attributes
which is specified by the `set_indexes` method.
Examples:
>>> class RearrangeInputsExample:
... def __init__(self):
... self._indexes = None
...
... @property
... def indexes(self):
... return getattr(self, '_indexes', None)
...
... def set_indexes(self, indexes):
... self._indexes = indexes
... return self
...
... @rearrange_inputs
... def update(self, *inputs):
... return inputs
>>>
>>> rearrange_inputs_example = RearrangeInputsExample().set_indexes([1, 0])
>>> outs = rearrange_inputs_example.update(5, 9)
>>> print(outs)
>>> (9, 5)
Args:
func (Callable): A candidate function to be wrapped whose input will be rearranged.
@ -43,9 +66,8 @@ class Metric(metaclass=ABCMeta):
"""
Base class of metric.
Note:
For examples of subclasses, please refer to the definition of class `MAE`, 'Recall' etc.
For examples of subclasses, please refer to the definition of class `MAE`, `Recall` etc.
"""
def __init__(self):
self._indexes = None
@ -117,9 +139,42 @@ class Metric(metaclass=ABCMeta):
@property
def indexes(self):
"""The `_indexes` is a private attributes, and you can retrieve it by `self.indexes`.
"""
return getattr(self, '_indexes', None)
def set_indexes(self, indexes):
"""
The `_indexes` is a private attributes, and you can modify it by this function.
This allows you to determine the order of logits and labels to be calculated in the
inputs, specially when you call the method `update` within this metrics.
Note:
It has been applied in subclass of Metric, eg. `Accuracy`, `BleuScore`, `ConfusionMatrix`,
`CosineSimilarity`, `MAE`, and `MSE`.
Args:
indexes (List(int)): The order of logits and labels to be rearranged.
Outputs:
:class:`Metric`, its original Class instance.
Examples:
>>> from mindspore import Tensor
>>> from mindspore.nn import Accuracy
>>>
>>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
>>> y = Tensor(np.array([1, 0, 1]))
>>> y2 = Tensor(np.array([0, 0, 1]))
>>> metric = Accuracy('classification').set_indexes([0, 2])
>>> metric.clear()
>>> metric.update(x, y, y2)
>>> accuracy = metric.eval()
>>> print(accuracy)
0.3333333333333333
"""
if not isinstance(indexes, list) or not all(isinstance(i, int) for i in indexes):
raise ValueError("The indexes should be a list and all its elements should be int")
self._indexes = indexes
return self

View File

@ -47,6 +47,12 @@ def test_classification_accuracy_indexes_awareness():
assert math.isclose(accuracy, 1 / 3)
@pytest.mark.parametrize('indexes', [0, [0., 2.], [0., 1], ['1', '0']])
def test_set_indexes(indexes):
with pytest.raises(ValueError, match="indexes should be a list and all its elements should be int"):
_ = Accuracy('classification').set_indexes(indexes)
def test_multilabel_accuracy():
x = Tensor(np.array([[0, 1, 0, 1], [1, 0, 1, 1], [0, 0, 0, 1]]))
y = Tensor(np.array([[0, 1, 1, 1], [0, 1, 1, 1], [0, 0, 0, 1]]))