forked from mindspore-Ecosystem/mindspore
!10040 nn_notes
From: @bai-yangfan Reviewed-by: @kingxian,@c_34 Signed-off-by: @kingxian
This commit is contained in:
commit
51ed499f3d
|
@ -39,6 +39,8 @@ class MAE(Metric):
|
|||
>>> error.clear()
|
||||
>>> error.update(x, y)
|
||||
>>> result = error.eval()
|
||||
>>> print(result)
|
||||
0.037499990314245224
|
||||
"""
|
||||
def __init__(self):
|
||||
super(MAE, self).__init__()
|
||||
|
|
|
@ -36,6 +36,7 @@ class Fbeta(Metric):
|
|||
>>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
|
||||
>>> y = Tensor(np.array([1, 0, 1]))
|
||||
>>> metric = nn.Fbeta(1)
|
||||
>>> metric.clear()
|
||||
>>> metric.update(x, y)
|
||||
>>> fbeta = metric.eval()
|
||||
>>> print(fbeta)
|
||||
|
|
|
@ -47,6 +47,9 @@ class Precision(EvaluationBase):
|
|||
>>> metric.clear()
|
||||
>>> metric.update(x, y)
|
||||
>>> precision = metric.eval()
|
||||
>>> print(precision)
|
||||
[0.5 1. ]
|
||||
|
||||
"""
|
||||
def __init__(self, eval_type='classification'):
|
||||
super(Precision, self).__init__(eval_type)
|
||||
|
|
|
@ -47,6 +47,8 @@ class Recall(EvaluationBase):
|
|||
>>> metric.clear()
|
||||
>>> metric.update(x, y)
|
||||
>>> recall = metric.eval()
|
||||
>>> print(recall)
|
||||
[1. 0.5]
|
||||
"""
|
||||
def __init__(self, eval_type='classification'):
|
||||
super(Recall, self).__init__(eval_type)
|
||||
|
|
|
@ -276,7 +276,7 @@ class GetNextSingleOp(Cell):
|
|||
>>> relu = P.ReLU()
|
||||
>>> result = relu(data).asnumpy()
|
||||
>>> print(result.shape)
|
||||
>>> (32, 1, 32, 32)
|
||||
(32, 1, 32, 32)
|
||||
"""
|
||||
|
||||
def __init__(self, dataset_types, dataset_shapes, queue_name):
|
||||
|
@ -356,6 +356,7 @@ class WithEvalCell(Cell):
|
|||
Args:
|
||||
network (Cell): The network Cell.
|
||||
loss_fn (Cell): The loss Cell.
|
||||
add_cast_fp32 (bool): Adjust the data type to float32.
|
||||
|
||||
Inputs:
|
||||
- **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
|
||||
|
@ -410,7 +411,7 @@ class ParameterUpdate(Cell):
|
|||
>>> param = network.parameters_dict()['weight']
|
||||
>>> update = nn.ParameterUpdate(param)
|
||||
>>> update.phase = "update_param"
|
||||
>>> weight = Tensor(np.arrange(12).reshape((4, 3)), mindspore.float32)
|
||||
>>> weight = Tensor(np.arange(12).reshape((4, 3)), mindspore.float32)
|
||||
>>> network_updata = update(weight)
|
||||
"""
|
||||
|
||||
|
|
Loading…
Reference in New Issue