!15075 add print function in exmple

From: @lijiaqi0612
Reviewed-by: @kingxian,@kisnwang
Signed-off-by: @kingxian
This commit is contained in:
mindspore-ci-bot 2021-04-13 19:33:42 +08:00 committed by Gitee
commit 7a67f28258
5 changed files with 14 additions and 4 deletions

View File

@ -40,6 +40,7 @@ def auc(x, y, reorder=False):
>>> metric.update(y_pred, y)
>>> fpr, tpr, thre = metric.eval()
>>> output = auc(fpr, tpr)
>>> print(output)
0.5357142857142857
"""
if not isinstance(x, np.ndarray) or not isinstance(y, np.ndarray):

View File

@ -32,14 +32,15 @@ class CosineSimilarity(Metric):
If sum or mean are used, then returns (b, 1) with the reduced value for each row
Example:
>>> test_data = np.random.randn(4, 8)
>>> test_data = np.array([[1, 3, 4, 7], [2, 4, 2, 5], [3, 1, 5, 8]])
>>> metric = CosineSimilarity()
>>> metric.clear()
>>> metric.update(test_data)
>>> square_matrix = metric.eval()
[[0. -0.14682831 0.19102288 -0.36204537]
...
]
>>> print(square_matrix)
[[0. 0.94025615 0.95162452]
[0.94025615 0. 0.86146098]
[0.95162452 0.86146098 0.]]
"""
def __init__(self, similarity='cosine', reduction='none', zero_diagonal=True):
super().__init__()

View File

@ -64,6 +64,7 @@ class OcclusionSensitivity(Metric):
>>> metric.clear()
>>> metric.update(model, test_data, label)
>>> score = metric.eval()
>>> print(score)
[0.29999995 0.6 1 0.9]
"""
def __init__(self, pad_val=0.0, margin=2, n_batch=128, b_box=None):

View File

@ -39,6 +39,7 @@ class Perplexity(Metric):
>>> metric.clear()
>>> metric.update(x, y)
>>> perplexity = metric.eval()
>>> print(perplexity)
2.231443166940565
"""

View File

@ -38,8 +38,11 @@ class ROC(Metric):
>>> metric.clear()
>>> metric.update(x, y)
>>> fpr, tpr, thresholds = metric.eval()
>>> print(fpr)
[0. 0. 0.33333333 0.6666667 1.]
>>> print(tpr)
[0. 1. 1. 1. 1.]
>>> print(thresholds)
[5 4 3 2 1]
>>>
>>> # 2) multiclass classification example
@ -50,9 +53,12 @@ class ROC(Metric):
>>> metric.clear()
>>> metric.update(x, y)
>>> fpr, tpr, thresholds = metric.eval()
>>> print(fpr)
[array([0., 0., 0.33333333, 0.66666667, 1.]), array([0., 0.33333333, 0.33333333, 1.]),
array([0., 0.33333333, 1.]), array([0., 0., 1.])]
>>> print(tpr)
[array([0., 1., 1., 1., 1.]), array([0., 0., 1., 1.]), array([0., 1., 1.]), array([0., 1., 1.])]
print(thresholds)
[array([1.28, 0.28, 0.2, 0.1, 0.05]), array([1.55, 0.55, 0.2, 0.05]), array([1.15, 0.15, 0.05]),
array([1.75, 0.75, 0.05])]
"""