forked from mindspore-Ecosystem/mindspore
!10040 nn_notes
From: @bai-yangfan Reviewed-by: @kingxian,@c_34 Signed-off-by: @kingxian
This commit is contained in:
commit
51ed499f3d
|
@ -39,6 +39,8 @@ class MAE(Metric):
|
||||||
>>> error.clear()
|
>>> error.clear()
|
||||||
>>> error.update(x, y)
|
>>> error.update(x, y)
|
||||||
>>> result = error.eval()
|
>>> result = error.eval()
|
||||||
|
>>> print(result)
|
||||||
|
0.037499990314245224
|
||||||
"""
|
"""
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(MAE, self).__init__()
|
super(MAE, self).__init__()
|
||||||
|
|
|
@ -36,10 +36,11 @@ class Fbeta(Metric):
|
||||||
>>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
|
>>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
|
||||||
>>> y = Tensor(np.array([1, 0, 1]))
|
>>> y = Tensor(np.array([1, 0, 1]))
|
||||||
>>> metric = nn.Fbeta(1)
|
>>> metric = nn.Fbeta(1)
|
||||||
|
>>> metric.clear()
|
||||||
>>> metric.update(x, y)
|
>>> metric.update(x, y)
|
||||||
>>> fbeta = metric.eval()
|
>>> fbeta = metric.eval()
|
||||||
>>> print(fbeta)
|
>>> print(fbeta)
|
||||||
[0.66666667 0.66666667]
|
[0.66666667 0.66666667]
|
||||||
"""
|
"""
|
||||||
def __init__(self, beta):
|
def __init__(self, beta):
|
||||||
super(Fbeta, self).__init__()
|
super(Fbeta, self).__init__()
|
||||||
|
@ -133,7 +134,7 @@ class F1(Fbeta):
|
||||||
>>> metric.update(x, y)
|
>>> metric.update(x, y)
|
||||||
>>> result = metric.eval()
|
>>> result = metric.eval()
|
||||||
>>> print(result)
|
>>> print(result)
|
||||||
[0.66666667 0.66666667]
|
[0.66666667 0.66666667]
|
||||||
"""
|
"""
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(F1, self).__init__(1.0)
|
super(F1, self).__init__(1.0)
|
||||||
|
|
|
@ -47,6 +47,9 @@ class Precision(EvaluationBase):
|
||||||
>>> metric.clear()
|
>>> metric.clear()
|
||||||
>>> metric.update(x, y)
|
>>> metric.update(x, y)
|
||||||
>>> precision = metric.eval()
|
>>> precision = metric.eval()
|
||||||
|
>>> print(precision)
|
||||||
|
[0.5 1. ]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def __init__(self, eval_type='classification'):
|
def __init__(self, eval_type='classification'):
|
||||||
super(Precision, self).__init__(eval_type)
|
super(Precision, self).__init__(eval_type)
|
||||||
|
|
|
@ -47,6 +47,8 @@ class Recall(EvaluationBase):
|
||||||
>>> metric.clear()
|
>>> metric.clear()
|
||||||
>>> metric.update(x, y)
|
>>> metric.update(x, y)
|
||||||
>>> recall = metric.eval()
|
>>> recall = metric.eval()
|
||||||
|
>>> print(recall)
|
||||||
|
[1. 0.5]
|
||||||
"""
|
"""
|
||||||
def __init__(self, eval_type='classification'):
|
def __init__(self, eval_type='classification'):
|
||||||
super(Recall, self).__init__(eval_type)
|
super(Recall, self).__init__(eval_type)
|
||||||
|
|
|
@ -276,7 +276,7 @@ class GetNextSingleOp(Cell):
|
||||||
>>> relu = P.ReLU()
|
>>> relu = P.ReLU()
|
||||||
>>> result = relu(data).asnumpy()
|
>>> result = relu(data).asnumpy()
|
||||||
>>> print(result.shape)
|
>>> print(result.shape)
|
||||||
>>> (32, 1, 32, 32)
|
(32, 1, 32, 32)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dataset_types, dataset_shapes, queue_name):
|
def __init__(self, dataset_types, dataset_shapes, queue_name):
|
||||||
|
@ -356,6 +356,7 @@ class WithEvalCell(Cell):
|
||||||
Args:
|
Args:
|
||||||
network (Cell): The network Cell.
|
network (Cell): The network Cell.
|
||||||
loss_fn (Cell): The loss Cell.
|
loss_fn (Cell): The loss Cell.
|
||||||
|
add_cast_fp32 (bool): Adjust the data type to float32.
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
- **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
|
- **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
|
||||||
|
@ -410,7 +411,7 @@ class ParameterUpdate(Cell):
|
||||||
>>> param = network.parameters_dict()['weight']
|
>>> param = network.parameters_dict()['weight']
|
||||||
>>> update = nn.ParameterUpdate(param)
|
>>> update = nn.ParameterUpdate(param)
|
||||||
>>> update.phase = "update_param"
|
>>> update.phase = "update_param"
|
||||||
>>> weight = Tensor(np.arrange(12).reshape((4, 3)), mindspore.float32)
|
>>> weight = Tensor(np.arange(12).reshape((4, 3)), mindspore.float32)
|
||||||
>>> network_updata = update(weight)
|
>>> network_updata = update(weight)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue