forked from mindspore-Ecosystem/mindspore
!17373 add docs, user-example and assertion for rearrange_inputs
Merge pull request !17373 from zhiqwang/rearrange-inputs-docs
This commit is contained in:
commit
c293f06b27
|
@ -23,7 +23,30 @@ _eval_types = {'classification', 'multilabel'}
|
|||
|
||||
def rearrange_inputs(func):
|
||||
"""
|
||||
This decorator is used to rearrange the inputs according to its indexes.
|
||||
This decorator is used to rearrange the inputs according to its `_indexes` attributes
|
||||
which is specified by the `set_indexes` method.
|
||||
|
||||
Examples:
|
||||
>>> class RearrangeInputsExample:
|
||||
... def __init__(self):
|
||||
... self._indexes = None
|
||||
...
|
||||
... @property
|
||||
... def indexes(self):
|
||||
... return getattr(self, '_indexes', None)
|
||||
...
|
||||
... def set_indexes(self, indexes):
|
||||
... self._indexes = indexes
|
||||
... return self
|
||||
...
|
||||
... @rearrange_inputs
|
||||
... def update(self, *inputs):
|
||||
... return inputs
|
||||
>>>
|
||||
>>> rearrange_inputs_example = RearrangeInputsExample().set_indexes([1, 0])
|
||||
>>> outs = rearrange_inputs_example.update(5, 9)
|
||||
>>> print(outs)
|
||||
>>> (9, 5)
|
||||
|
||||
Args:
|
||||
func (Callable): A candidate function to be wrapped whose input will be rearranged.
|
||||
|
@ -43,9 +66,8 @@ class Metric(metaclass=ABCMeta):
|
|||
"""
|
||||
Base class of metric.
|
||||
|
||||
|
||||
Note:
|
||||
For examples of subclasses, please refer to the definition of class `MAE`, 'Recall' etc.
|
||||
For examples of subclasses, please refer to the definition of class `MAE`, `Recall` etc.
|
||||
"""
|
||||
def __init__(self):
|
||||
self._indexes = None
|
||||
|
@ -117,9 +139,42 @@ class Metric(metaclass=ABCMeta):
|
|||
|
||||
@property
|
||||
def indexes(self):
|
||||
"""The `_indexes` is a private attributes, and you can retrieve it by `self.indexes`.
|
||||
"""
|
||||
return getattr(self, '_indexes', None)
|
||||
|
||||
def set_indexes(self, indexes):
|
||||
"""
|
||||
The `_indexes` is a private attributes, and you can modify it by this function.
|
||||
This allows you to determine the order of logits and labels to be calculated in the
|
||||
inputs, specially when you call the method `update` within this metrics.
|
||||
|
||||
Note:
|
||||
It has been applied in subclass of Metric, eg. `Accuracy`, `BleuScore`, `ConfusionMatrix`,
|
||||
`CosineSimilarity`, `MAE`, and `MSE`.
|
||||
|
||||
Args:
|
||||
indexes (List(int)): The order of logits and labels to be rearranged.
|
||||
|
||||
Outputs:
|
||||
:class:`Metric`, its original Class instance.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore import Tensor
|
||||
>>> from mindspore.nn import Accuracy
|
||||
>>>
|
||||
>>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
|
||||
>>> y = Tensor(np.array([1, 0, 1]))
|
||||
>>> y2 = Tensor(np.array([0, 0, 1]))
|
||||
>>> metric = Accuracy('classification').set_indexes([0, 2])
|
||||
>>> metric.clear()
|
||||
>>> metric.update(x, y, y2)
|
||||
>>> accuracy = metric.eval()
|
||||
>>> print(accuracy)
|
||||
0.3333333333333333
|
||||
"""
|
||||
if not isinstance(indexes, list) or not all(isinstance(i, int) for i in indexes):
|
||||
raise ValueError("The indexes should be a list and all its elements should be int")
|
||||
self._indexes = indexes
|
||||
return self
|
||||
|
||||
|
|
|
@ -47,6 +47,12 @@ def test_classification_accuracy_indexes_awareness():
|
|||
assert math.isclose(accuracy, 1 / 3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('indexes', [0, [0., 2.], [0., 1], ['1', '0']])
|
||||
def test_set_indexes(indexes):
|
||||
with pytest.raises(ValueError, match="indexes should be a list and all its elements should be int"):
|
||||
_ = Accuracy('classification').set_indexes(indexes)
|
||||
|
||||
|
||||
def test_multilabel_accuracy():
|
||||
x = Tensor(np.array([[0, 1, 0, 1], [1, 0, 1, 1], [0, 0, 0, 1]]))
|
||||
y = Tensor(np.array([[0, 1, 1, 1], [0, 1, 1, 1], [0, 0, 0, 1]]))
|
||||
|
|
Loading…
Reference in New Issue