From: @bai-yangfan
Reviewed-by: @kingxian,@c_34
Signed-off-by: @kingxian
This commit is contained in:
mindspore-ci-bot 2020-12-16 14:28:44 +08:00 committed by Gitee
commit 51ed499f3d
5 changed files with 13 additions and 4 deletions

View File

@ -39,6 +39,8 @@ class MAE(Metric):
>>> error.clear()
>>> error.update(x, y)
>>> result = error.eval()
>>> print(result)
0.037499990314245224
"""
def __init__(self):
super(MAE, self).__init__()

View File

@ -36,10 +36,11 @@ class Fbeta(Metric):
>>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
>>> y = Tensor(np.array([1, 0, 1]))
>>> metric = nn.Fbeta(1)
>>> metric.clear()
>>> metric.update(x, y)
>>> fbeta = metric.eval()
>>> print(fbeta)
[0.66666667 0.66666667]
[0.66666667 0.66666667]
"""
def __init__(self, beta):
super(Fbeta, self).__init__()
@ -133,7 +134,7 @@ class F1(Fbeta):
>>> metric.update(x, y)
>>> result = metric.eval()
>>> print(result)
[0.66666667 0.66666667]
[0.66666667 0.66666667]
"""
def __init__(self):
super(F1, self).__init__(1.0)

View File

@ -47,6 +47,9 @@ class Precision(EvaluationBase):
>>> metric.clear()
>>> metric.update(x, y)
>>> precision = metric.eval()
>>> print(precision)
[0.5 1. ]
"""
def __init__(self, eval_type='classification'):
super(Precision, self).__init__(eval_type)

View File

@ -47,6 +47,8 @@ class Recall(EvaluationBase):
>>> metric.clear()
>>> metric.update(x, y)
>>> recall = metric.eval()
>>> print(recall)
[1. 0.5]
"""
def __init__(self, eval_type='classification'):
super(Recall, self).__init__(eval_type)

View File

@ -276,7 +276,7 @@ class GetNextSingleOp(Cell):
>>> relu = P.ReLU()
>>> result = relu(data).asnumpy()
>>> print(result.shape)
>>> (32, 1, 32, 32)
(32, 1, 32, 32)
"""
def __init__(self, dataset_types, dataset_shapes, queue_name):
@ -356,6 +356,7 @@ class WithEvalCell(Cell):
Args:
network (Cell): The network Cell.
loss_fn (Cell): The loss Cell.
add_cast_fp32 (bool): Adjust the data type to float32.
Inputs:
- **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
@ -410,7 +411,7 @@ class ParameterUpdate(Cell):
>>> param = network.parameters_dict()['weight']
>>> update = nn.ParameterUpdate(param)
>>> update.phase = "update_param"
>>> weight = Tensor(np.arrange(12).reshape((4, 3)), mindspore.float32)
>>> weight = Tensor(np.arange(12).reshape((4, 3)), mindspore.float32)
>>> network_updata = update(weight)
"""