!126 resolve some issues in nn comments

Merge pull request !126 from zhongligeng/master
This commit is contained in:
mindspore-ci-bot 2020-04-06 10:14:50 +08:00 committed by Gitee
commit 32017f6da3
25 changed files with 128 additions and 150 deletions

View File

@ -65,7 +65,7 @@ class Dropout(Cell):
Tensor, output tensor with the same shape as the input. Tensor, output tensor with the same shape as the input.
Examples: Examples:
>>> x = mindspore.Tensor(np.ones([20, 16, 50]), mindspore.float32) >>> x = Tensor(np.ones([20, 16, 50]), mindspore.float32)
>>> net = nn.Dropout(keep_prob=0.8) >>> net = nn.Dropout(keep_prob=0.8)
>>> net(x) >>> net(x)
""" """
@ -111,7 +111,7 @@ class Flatten(Cell):
Examples: Examples:
>>> net = nn.Flatten() >>> net = nn.Flatten()
>>> input = mindspore.Tensor(np.array([[[1.2, 1.2], [2.1, 2.1]], [[2.2, 2.2], [3.2, 3.2]]]), mindspore.float32) >>> input = Tensor(np.array([[[1.2, 1.2], [2.1, 2.1]], [[2.2, 2.2], [3.2, 3.2]]]), mindspore.float32)
>>> input.shape() >>> input.shape()
(2, 2, 2) (2, 2, 2)
>>> net(input) >>> net(input)
@ -149,9 +149,6 @@ class Dense(Cell):
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None. activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
Returns:
Tensor, output tensor.
Raises: Raises:
ValueError: If weight_init or bias_init shape is incorrect. ValueError: If weight_init or bias_init shape is incorrect.
@ -163,7 +160,7 @@ class Dense(Cell):
Examples: Examples:
>>> net = nn.Dense(3, 4) >>> net = nn.Dense(3, 4)
>>> input = mindspore.Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32) >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
>>> net(input) >>> net(input)
[[ 2.5246444 2.2738023 0.5711005 -3.9399147 ] [[ 2.5246444 2.2738023 0.5711005 -3.9399147 ]
[ 1.0739875 4.0155234 0.94188046 -5.459526 ]] [ 1.0739875 4.0155234 0.94188046 -5.459526 ]]
@ -243,8 +240,8 @@ class ClipByNorm(Cell):
Examples: Examples:
>>> net = nn.ClipByNorm() >>> net = nn.ClipByNorm()
>>> input = mindspore.Tensor(np.random.randint(0, 10, [4, 16]), mindspore.float32) >>> input = Tensor(np.random.randint(0, 10, [4, 16]), mindspore.float32)
>>> clip_norm = mindspore.Tensor(np.array([100]).astype(np.float32)) >>> clip_norm = Tensor(np.array([100]).astype(np.float32))
>>> net(input, clip_norm) >>> net(input, clip_norm)
""" """
@ -290,9 +287,6 @@ class Norm(Cell):
keep_dims (bool): If True, the axis indicated in `axis` are kept with size 1. Otherwise, keep_dims (bool): If True, the axis indicated in `axis` are kept with size 1. Otherwise,
the dimensions in `axis` are removed from the output shape. Default: False. the dimensions in `axis` are removed from the output shape. Default: False.
Returns:
Tensor, a Tensor of the same type as input, containing the vector or matrix norms.
Inputs: Inputs:
- **input** (Tensor) - Tensor which is not empty. - **input** (Tensor) - Tensor which is not empty.
@ -302,7 +296,7 @@ class Norm(Cell):
Examples: Examples:
>>> net = nn.Norm(axis=0) >>> net = nn.Norm(axis=0)
>>> input = mindspore.Tensor(np.random.randint(0, 10, [4, 16]), mindspore.float32) >>> input = Tensor(np.random.randint(0, 10, [4, 16]), mindspore.float32)
>>> net(input) >>> net(input)
""" """
def __init__(self, axis=(), keep_dims=False): def __init__(self, axis=(), keep_dims=False):
@ -344,7 +338,8 @@ class OneHot(Cell):
when indices[j] = i. Default: 1.0. when indices[j] = i. Default: 1.0.
off_value (float): A scalar defining the value to fill in output[i][j] off_value (float): A scalar defining the value to fill in output[i][j]
when indices[j] != i. Default: 0.0. when indices[j] != i. Default: 0.0.
dtype (:class:`mindspore.dtype`): Default: mindspore.float32. dtype (:class:`mindspore.dtype`): Data type of 'on_value' and 'off_value', not the
data type of indices. Default: mindspore.float32.
Inputs: Inputs:
- **indices** (Tensor) - A tensor of indices of data type mindspore.int32 and arbitrary shape. - **indices** (Tensor) - A tensor of indices of data type mindspore.int32 and arbitrary shape.
@ -355,7 +350,7 @@ class OneHot(Cell):
Examples: Examples:
>>> net = nn.OneHot(depth=4, axis=1) >>> net = nn.OneHot(depth=4, axis=1)
>>> indices = mindspore.Tensor([[1, 3], [0, 2]], dtype=mindspore.int32) >>> indices = Tensor([[1, 3], [0, 2]], dtype=mindspore.int32)
>>> net(indices) >>> net(indices)
[[[0. 0.] [[[0. 0.]
[1. 0.] [1. 0.]

View File

@ -86,7 +86,7 @@ class SequentialCell(Cell):
>>> relu = nn.ReLU() >>> relu = nn.ReLU()
>>> seq = nn.SequentialCell([conv, bn, relu]) >>> seq = nn.SequentialCell([conv, bn, relu])
>>> >>>
>>> x = mindspore.Tensor(np.random.random((1, 3, 4, 4)), dtype=mindspore.float32) >>> x = Tensor(np.random.random((1, 3, 4, 4)), dtype=mindspore.float32)
>>> seq(x) >>> seq(x)
[[[[0.02531557 0. ] [[[[0.02531557 0. ]
[0.04933941 0.04880078]] [0.04933941 0.04880078]]
@ -138,7 +138,6 @@ class SequentialCell(Cell):
return len(self._cells) return len(self._cells)
def construct(self, input_data): def construct(self, input_data):
"""Processes the input with the defined sequence of Cells."""
for cell in self.cell_list: for cell in self.cell_list:
input_data = cell(input_data) input_data = cell(input_data)
return input_data return input_data
@ -161,7 +160,7 @@ class CellList(_CellListBase, Cell):
>>> cell_ls = nn.CellList([bn]) >>> cell_ls = nn.CellList([bn])
>>> cell_ls.insert(0, conv) >>> cell_ls.insert(0, conv)
>>> cell_ls.append(relu) >>> cell_ls.append(relu)
>>> x = mindspore.Tensor(np.random.random((1, 3, 4, 4)), dtype=mindspore.float32) >>> x = Tensor(np.random.random((1, 3, 4, 4)), dtype=mindspore.float32)
>>> # not same as nn.SequentialCell, `cell_ls(x)` is not correct >>> # not same as nn.SequentialCell, `cell_ls(x)` is not correct
>>> cell_ls >>> cell_ls
CellList< (0): Conv2d<input_channels=100, ..., bias_init=None> CellList< (0): Conv2d<input_channels=100, ..., bias_init=None>

View File

@ -146,9 +146,6 @@ class Conv2d(_Conv):
Initializer and string are the same as 'weight_init'. Refer to the values of Initializer and string are the same as 'weight_init'. Refer to the values of
Initializer for more details. Default: 'zeros'. Initializer for more details. Default: 'zeros'.
Returns:
Tensor, output tensor.
Inputs: Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@ -157,7 +154,7 @@ class Conv2d(_Conv):
Examples: Examples:
>>> net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal') >>> net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
>>> input = mindspore.Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32) >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
>>> net(input).shape() >>> net(input).shape()
(1, 240, 1024, 640) (1, 240, 1024, 640)
""" """
@ -277,7 +274,7 @@ class Conv2dTranspose(_Conv):
Examples: Examples:
>>> net = nn.Conv2dTranspose(3, 64, 4, has_bias=False, weight_init='normal') >>> net = nn.Conv2dTranspose(3, 64, 4, has_bias=False, weight_init='normal')
>>> input = Tensor(np.ones([1, 3, 16, 50]), mstype.float32) >>> input = Tensor(np.ones([1, 3, 16, 50]), mindspore.float32)
>>> net(input) >>> net(input)
""" """
def __init__(self, def __init__(self,

View File

@ -50,7 +50,7 @@ class Embedding(Cell):
Examples: Examples:
>>> net = nn.Embedding(20000, 768, True) >>> net = nn.Embedding(20000, 768, True)
>>> input_data = mindspore.Tensor(np.ones([8, 128]), mindspore.int32) >>> input_data = Tensor(np.ones([8, 128]), mindspore.int32)
>>> >>>
>>> # Maps the input word IDs to word embedding. >>> # Maps the input word IDs to word embedding.
>>> output = net(input_data) >>> output = net(input_data)

View File

@ -96,9 +96,9 @@ class LSTM(Cell):
>>> return self.lstm(inp, (h0, c0)) >>> return self.lstm(inp, (h0, c0))
>>> >>>
>>> net = LstmNet(10, 12, 2, has_bias=True, batch_first=True, bidirectional=False) >>> net = LstmNet(10, 12, 2, has_bias=True, batch_first=True, bidirectional=False)
>>> input = mindspore.Tensor(np.ones([3, 5, 10]).astype(np.float32)) >>> input = Tensor(np.ones([3, 5, 10]).astype(np.float32))
>>> h0 = mindspore.Tensor(np.ones([1 * 2, 3, 12]).astype(np.float32)) >>> h0 = Tensor(np.ones([1 * 2, 3, 12]).astype(np.float32))
>>> c0 = mindspore.Tensor(np.ones([1 * 2, 3, 12]).astype(np.float32)) >>> c0 = Tensor(np.ones([1 * 2, 3, 12]).astype(np.float32))
>>> output, (hn, cn) = net(input, h0, c0) >>> output, (hn, cn) = net(input, h0, c0)
""" """
def __init__(self, def __init__(self,

View File

@ -159,7 +159,7 @@ class BatchNorm1d(_BatchNorm):
Examples: Examples:
>>> net = nn.BatchNorm1d(num_features=16) >>> net = nn.BatchNorm1d(num_features=16)
>>> input = mindspore.Tensor(np.random.randint(0, 255, [3, 16]), mindspore.float32) >>> input = Tensor(np.random.randint(0, 255, [3, 16]), mindspore.float32)
>>> net(input) >>> net(input)
""" """
def _check_data_dim(self, x): def _check_data_dim(self, x):
@ -258,7 +258,7 @@ class LayerNorm(Cell):
Examples: Examples:
>>> x = Tensor(np.ones([20, 5, 10, 10], np.float32)) >>> x = Tensor(np.ones([20, 5, 10, 10], np.float32))
>>> shape1 = x.shape()[1:] >>> shape1 = x.shape()[1:]
>>> m = LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1) >>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
>>> m(x) >>> m(x)
""" """
def __init__(self, def __init__(self,

View File

@ -63,8 +63,8 @@ class MaxPool2d(_PoolNd):
pad_mode for training only supports "same" and "valid". pad_mode for training only supports "same" and "valid".
Args: Args:
kernel_size (int): Size of the window to take a max over. kernel_size (int): Size of the window to take a max over. Default 1.
stride (int): Stride size of the window. Default: None. stride (int): Stride size of the window. Default: 1.
pad_mode (str): Select the mode of the pad. The optional values are pad_mode (str): Select the mode of the pad. The optional values are
"same" and "valid". Default: "valid". "same" and "valid". Default: "valid".
@ -75,7 +75,7 @@ class MaxPool2d(_PoolNd):
- valid: Adopts the way of discarding. The possibly largest height and width of output will be return - valid: Adopts the way of discarding. The possibly largest height and width of output will be return
without padding. Extra pixels will be discarded. without padding. Extra pixels will be discarded.
padding (int): Now is not supported, mplicit zero padding to be added on both sides. Default: 0. padding (int): Implicit zero padding to be added on both sides. Default: 0.
Inputs: Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@ -85,7 +85,7 @@ class MaxPool2d(_PoolNd):
Examples: Examples:
>>> pool = MaxPool2d(kernel_size=3, stride=1) >>> pool = MaxPool2d(kernel_size=3, stride=1)
>>> x = mindspore.Tensor(np.random.randint(0, 10, [1, 2, 4, 4]), mindspore.float32) >>> x = Tensor(np.random.randint(0, 10, [1, 2, 4, 4]), mindspore.float32)
[[[[1. 5. 5. 1.] [[[[1. 5. 5. 1.]
[0. 3. 4. 8.] [0. 3. 4. 8.]
[4. 2. 7. 6.] [4. 2. 7. 6.]
@ -149,8 +149,8 @@ class AvgPool2d(_PoolNd):
pad_mode for training only supports "same" and "valid". pad_mode for training only supports "same" and "valid".
Args: Args:
kernel_size (int): Size of the window to take a max over. kernel_size (int): Size of the window to take a max over. Default: 1.
stride (int): Stride size of the window. Default: None. stride (int): Stride size of the window. Default: 1.
pad_mode (str): Select the mode of the pad. The optional values are pad_mode (str): Select the mode of the pad. The optional values are
"same", "valid". Default: "valid". "same", "valid". Default: "valid".
@ -161,7 +161,7 @@ class AvgPool2d(_PoolNd):
- valid: Adopts the way of discarding. The possibly largest height and width of output will be return - valid: Adopts the way of discarding. The possibly largest height and width of output will be return
without padding. Extra pixels will be discarded. without padding. Extra pixels will be discarded.
padding (int): Now is not supported, implicit zero padding to be added on both sides. Default: 0. padding (int): Implicit zero padding to be added on both sides. Default: 0.
Inputs: Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@ -171,7 +171,7 @@ class AvgPool2d(_PoolNd):
Examples: Examples:
>>> pool = AvgPool2d(kernel_size=3, stride=1) >>> pool = AvgPool2d(kernel_size=3, stride=1)
>>> x = mindspore.Tensor(np.random.randint(0, 10, [1, 2, 4, 4]), mindspore.float32) >>> x = Tensor(np.random.randint(0, 10, [1, 2, 4, 4]), mindspore.float32)
[[[[5. 5. 9. 9.] [[[[5. 5. 9. 9.]
[8. 4. 3. 0.] [8. 4. 3. 0.]
[2. 7. 1. 2.] [2. 7. 1. 2.]

View File

@ -86,9 +86,9 @@ class L1Loss(_Loss):
Tensor, loss float tensor. Tensor, loss float tensor.
Examples: Examples:
>>> loss = L1Loss() >>> loss = nn.L1Loss()
>>> input_data = Tensor(np.array([1, 2, 3]), mstype.float32) >>> input_data = Tensor(np.array([1, 2, 3]), mindspore.float32)
>>> target_data = Tensor(np.array([1, 2, 2]), mstype.float32) >>> target_data = Tensor(np.array([1, 2, 2]), mindspore.float32)
>>> loss(input_data, target_data) >>> loss(input_data, target_data)
""" """
def __init__(self, reduction='mean'): def __init__(self, reduction='mean'):
@ -126,9 +126,9 @@ class MSELoss(_Loss):
Tensor, weighted loss float tensor. Tensor, weighted loss float tensor.
Examples: Examples:
>>> loss = MSELoss() >>> loss = nn.MSELoss()
>>> input_data = Tensor(np.array([1, 2, 3]), mstype.float32) >>> input_data = Tensor(np.array([1, 2, 3]), mindspore.float32)
>>> target_data = Tensor(np.array([1, 2, 2]), mstype.float32) >>> target_data = Tensor(np.array([1, 2, 2]), mindspore.float32)
>>> loss(input_data, target_data) >>> loss(input_data, target_data)
""" """
def construct(self, base, target): def construct(self, base, target):
@ -171,9 +171,9 @@ class SmoothL1Loss(_Loss):
Tensor, loss float tensor. Tensor, loss float tensor.
Examples: Examples:
>>> loss = SmoothL1Loss() >>> loss = nn.SmoothL1Loss()
>>> input_data = Tensor(np.array([1, 2, 3]), mstype.float32) >>> input_data = Tensor(np.array([1, 2, 3]), mindspore.float32)
>>> target_data = Tensor(np.array([1, 2, 2]), mstype.float32) >>> target_data = Tensor(np.array([1, 2, 2]), mindspore.float32)
>>> loss(input_data, target_data) >>> loss(input_data, target_data)
""" """
def __init__(self, sigma=1.0): def __init__(self, sigma=1.0):
@ -219,17 +219,16 @@ class SoftmaxCrossEntropyWithLogits(_Loss):
Inputs: Inputs:
- **logits** (Tensor) - Tensor of shape :math:`(x_1, x_2, ..., x_R)`. - **logits** (Tensor) - Tensor of shape :math:`(x_1, x_2, ..., x_R)`.
- **labels** (Tensor) - Tensor of shape :math:`(y_1, y_2, ..., y_S)`. If `sparse` is True, The type of - **labels** (Tensor) - Tensor of shape :math:`(y_1, y_2, ..., y_S)`. If `sparse` is True, The type of
`labels` is mstype.int32. If `sparse` is False, the type of `labels` is same as the type of `logits`. `labels` is mindspore.int32. If `sparse` is False, the type of `labels` is same as the type of `logits`.
Outputs: Outputs:
Tensor, a tensor of the same shape as logits with the component-wise Tensor, a tensor of the same shape as logits with the component-wise
logistic losses. logistic losses.
Examples: Examples:
>>> loss = SoftmaxCrossEntropyWithLogits(sparse=True) >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
>>> logits = Tensor(np.random.randint(0, 9, [1, 10]), mstype.float32) >>> logits = Tensor(np.random.randint(0, 9, [1, 10]), mindspore.float32)
>>> labels_np = np.zeros([1, 10]).astype(np.int32) >>> labels_np = np.ones([1,]).astype(np.int32)
>>> labels_np[0][0] = 1
>>> labels = Tensor(labels_np) >>> labels = Tensor(labels_np)
>>> loss(logits, labels) >>> loss(logits, labels)
""" """
@ -286,8 +285,8 @@ class SoftmaxCrossEntropyExpand(Cell):
Examples: Examples:
>>> loss = SoftmaxCrossEntropyExpand(sparse=True) >>> loss = SoftmaxCrossEntropyExpand(sparse=True)
>>> input_data = Tensor(np.ones([64, 512]), dtype=mstype.float32) >>> input_data = Tensor(np.ones([64, 512]), dtype=mindspore.float32)
>>> label = Tensor(np.ones([64]), dtype=mstype.int32) >>> label = Tensor(np.ones([64]), dtype=mindspore.int32)
>>> loss(input_data, label) >>> loss(input_data, label)
""" """
def __init__(self, sparse=False): def __init__(self, sparse=False):

View File

@ -35,8 +35,8 @@ class Accuracy(EvaluationBase):
Default: 'classification'. Default: 'classification'.
Examples: Examples:
>>> x = mindspore.Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mindspore.float32) >>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mindspore.float32)
>>> y = mindspore.Tensor(np.array([1, 0, 1]), mindspore.float32) >>> y = Tensor(np.array([1, 0, 1]), mindspore.float32)
>>> metric = nn.Accuracy('classification') >>> metric = nn.Accuracy('classification')
>>> metric.clear() >>> metric.clear()
>>> metric.update(x, y) >>> metric.update(x, y)
@ -58,13 +58,14 @@ class Accuracy(EvaluationBase):
Args: Args:
inputs: Input `y_pred` and `y`. `y_pred` and `y` are a `Tensor`, a list or an array. inputs: Input `y_pred` and `y`. `y_pred` and `y` are a `Tensor`, a list or an array.
`y_pred` is in most cases (not strictly) a list of floating numbers in range :math:`[0, 1]` For 'classification' evaluation type, `y_pred` is in most cases (not strictly) a list
of floating numbers in range :math:`[0, 1]`
and the shape is :math:`(N, C)`, where :math:`N` is the number of cases and :math:`C` and the shape is :math:`(N, C)`, where :math:`N` is the number of cases and :math:`C`
is the number of categories. For 'multilabel' evaluation type, `y_pred` can only be one-hot is the number of categories. Shape of `y` can be :math:`(N, C)` with values 0 and 1 if one-hot
encoding with values 0 or 1. Indices with 1 indicate positive category. `y` contains values encoding is used or the shape is :math:`(N,)` with integer values if index of category is used.
of integers. The shape is :math:`(N, C)` if one-hot encoding is used. One-hot encoding For 'multilabel' evaluation type, `y_pred` and `y` can only be one-hot encoding with
should be used when 'eval_type' is 'multilabel'. Shape can also be :math:`(N, 1)` if category values 0 or 1. Indices with 1 indicate positive category. The shape of `y_pred` and `y`
index is used in 'classification' evaluation type. are both :math:`(N, C)`.
Raises: Raises:
ValueError: If the number of the input is not 2. ValueError: If the number of the input is not 2.

View File

@ -33,8 +33,8 @@ class MAE(Metric):
The method `update` must be called with the form `update(y_pred, y)`. The method `update` must be called with the form `update(y_pred, y)`.
Examples: Examples:
>>> x = mindspore.Tensor(np.array([0.1, 0.2, 0.6, 0.9]), mindspore.float32) >>> x = Tensor(np.array([0.1, 0.2, 0.6, 0.9]), mindspore.float32)
>>> y = mindspore.Tensor(np.array([0.1, 0.25, 0.7, 0.9]), mindspore.float32) >>> y = Tensor(np.array([0.1, 0.25, 0.7, 0.9]), mindspore.float32)
>>> error = nn.MAE() >>> error = nn.MAE()
>>> error.clear() >>> error.clear()
>>> error.update(x, y) >>> error.update(x, y)
@ -95,8 +95,8 @@ class MSE(Metric):
where :math:`n` is batch size. where :math:`n` is batch size.
Examples: Examples:
>>> x = mindspore.Tensor(np.array([0.1, 0.2, 0.6, 0.9]), mindspore.float32) >>> x = Tensor(np.array([0.1, 0.2, 0.6, 0.9]), mindspore.float32)
>>> y = mindspore.Tensor(np.array([0.1, 0.25, 0.5, 0.9]), mindspore.float32) >>> y = Tensor(np.array([0.1, 0.25, 0.5, 0.9]), mindspore.float32)
>>> error = MSE() >>> error = MSE()
>>> error.clear() >>> error.clear()
>>> error.update(x, y) >>> error.update(x, y)

View File

@ -33,12 +33,11 @@ class Fbeta(Metric):
beta (float): The weight of precision. beta (float): The weight of precision.
Examples: Examples:
>>> x = mindspore.Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]])) >>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
>>> y = mindspore.Tensor(np.array([1, 0, 1])) >>> y = Tensor(np.array([1, 0, 1]))
>>> metric = nn.Fbeta(1) >>> metric = nn.Fbeta(1)
>>> metric.update(x, y) >>> metric.update(x, y)
>>> fbeta = metric.eval() >>> fbeta = metric.eval()
[0.66666667 0.66666667]
""" """
def __init__(self, beta): def __init__(self, beta):
super(Fbeta, self).__init__() super(Fbeta, self).__init__()
@ -64,7 +63,7 @@ class Fbeta(Metric):
`y_pred` is in most cases (not strictly) a list of floating numbers in range :math:`[0, 1]` `y_pred` is in most cases (not strictly) a list of floating numbers in range :math:`[0, 1]`
and the shape is :math:`(N, C)`, where :math:`N` is the number of cases and :math:`C` and the shape is :math:`(N, C)`, where :math:`N` is the number of cases and :math:`C`
is the number of categories. y contains values of integers. The shape is :math:`(N, C)` is the number of categories. y contains values of integers. The shape is :math:`(N, C)`
if one-hot encoding is used. Shape can also be :math:`(N, 1)` if category index is used. if one-hot encoding is used. Shape can also be :math:`(N,)` if category index is used.
""" """
if len(inputs) != 2: if len(inputs) != 2:
raise ValueError('Fbeta need 2 inputs (y_pred, y), but got {}'.format(len(inputs))) raise ValueError('Fbeta need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
@ -126,8 +125,8 @@ class F1(Fbeta):
F_\beta=\frac{2\cdot true\_positive}{2\cdot true\_positive + false\_negative + false\_positive} F_\beta=\frac{2\cdot true\_positive}{2\cdot true\_positive + false\_negative + false\_positive}
Examples: Examples:
>>> x = mindspore.Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]])) >>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
>>> y = mindspore.Tensor(np.array([1, 0, 1])) >>> y = Tensor(np.array([1, 0, 1]))
>>> metric = nn.F1() >>> metric = nn.F1()
>>> metric.update(x, y) >>> metric.update(x, y)
>>> fbeta = metric.eval() >>> fbeta = metric.eval()

View File

@ -25,12 +25,11 @@ class Loss(Metric):
loss = \frac{\sum_{k=1}^{n}loss_k}{n} loss = \frac{\sum_{k=1}^{n}loss_k}{n}
Examples: Examples:
>>> x = mindspore.Tensor(np.array(0.2), mindspore.float32) >>> x = Tensor(np.array(0.2), mindspore.float32)
>>> loss = nn.Loss() >>> loss = nn.Loss()
>>> loss.clear() >>> loss.clear()
>>> loss.update(x) >>> loss.update(x)
>>> result = loss.eval() >>> result = loss.eval()
0.20000000298023224
""" """
def __init__(self): def __init__(self):
super(Loss, self).__init__() super(Loss, self).__init__()

View File

@ -41,13 +41,12 @@ class Precision(EvaluationBase):
multilabel. Default: 'classification'. multilabel. Default: 'classification'.
Examples: Examples:
>>> x = mindspore.Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]])) >>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
>>> y = mindspore.Tensor(np.array([1, 0, 1])) >>> y = Tensor(np.array([1, 0, 1]))
>>> metric = nn.Precision('classification') >>> metric = nn.Precision('classification')
>>> metric.clear() >>> metric.clear()
>>> metric.update(x, y) >>> metric.update(x, y)
>>> precision = metric.eval() >>> precision = metric.eval()
[0.5 1. ]
""" """
def __init__(self, eval_type='classification'): def __init__(self, eval_type='classification'):
super(Precision, self).__init__(eval_type) super(Precision, self).__init__(eval_type)
@ -72,13 +71,14 @@ class Precision(EvaluationBase):
Args: Args:
inputs: Input `y_pred` and `y`. `y_pred` and `y` are Tensor, list or numpy.ndarray. inputs: Input `y_pred` and `y`. `y_pred` and `y` are Tensor, list or numpy.ndarray.
`y_pred` is in most cases (not strictly) a list of floating numbers in range :math:`[0, 1]` For 'classification' evaluation type, `y_pred` is in most cases (not strictly) a list
of floating numbers in range :math:`[0, 1]`
and the shape is :math:`(N, C)`, where :math:`N` is the number of cases and :math:`C` and the shape is :math:`(N, C)`, where :math:`N` is the number of cases and :math:`C`
is the number of categories. For 'multilabel' evaluation type, `y_pred` can only be one-hot is the number of categories. Shape of `y` can be :math:`(N, C)` with values 0 and 1 if one-hot
encoding with values 0 or 1. Indices with 1 indicate positive category. `y` contains values encoding is used or the shape is :math:`(N,)` with integer values if index of category is used.
of integers. The shape is :math:`(N, C)` if one-hot encoding is used. One-hot encoding For 'multilabel' evaluation type, `y_pred` and `y` can only be one-hot encoding with
should be used when 'eval_type' is 'multilabel'. Shape can also be :math:`(N, 1)` if category values 0 or 1. Indices with 1 indicate positive category. The shape of `y_pred` and `y`
index is used in 'classification' evaluation type. are both :math:`(N, C)`.
Raises: Raises:
ValueError: If the number of input is not 2. ValueError: If the number of input is not 2.

View File

@ -41,13 +41,12 @@ class Recall(EvaluationBase):
multilabel. Default: 'classification'. multilabel. Default: 'classification'.
Examples: Examples:
>>> x = mindspore.Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]])) >>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
>>> y = mindspore.Tensor(np.array([1, 0, 1])) >>> y = Tensor(np.array([1, 0, 1]))
>>> metric = nn.Recall('classification') >>> metric = nn.Recall('classification')
>>> metric.clear() >>> metric.clear()
>>> metric.update(x, y) >>> metric.update(x, y)
>>> recall = metric.eval() >>> recall = metric.eval()
[1. 0.5]
""" """
def __init__(self, eval_type='classification'): def __init__(self, eval_type='classification'):
super(Recall, self).__init__(eval_type) super(Recall, self).__init__(eval_type)
@ -72,13 +71,14 @@ class Recall(EvaluationBase):
Args: Args:
inputs: Input `y_pred` and `y`. `y_pred` and `y` are a `Tensor`, a list or an array. inputs: Input `y_pred` and `y`. `y_pred` and `y` are a `Tensor`, a list or an array.
`y_pred` is in most cases (not strictly) a list of floating numbers in range :math:`[0, 1]` For 'classification' evaluation type, `y_pred` is in most cases (not strictly) a list
of floating numbers in range :math:`[0, 1]`
and the shape is :math:`(N, C)`, where :math:`N` is the number of cases and :math:`C` and the shape is :math:`(N, C)`, where :math:`N` is the number of cases and :math:`C`
is the number of categories. For 'multilabel' evaluation type, `y_pred` can only be one-hot is the number of categories. Shape of `y` can be :math:`(N, C)` with values 0 and 1 if one-hot
encoding with values 0 or 1. Indices with 1 indicate positive category. `y` contains values encoding is used or the shape is :math:`(N,)` with integer values if index of category is used.
of integers. The shape is :math:`(N, C)` if one-hot encoding is used. One-hot encoding For 'multilabel' evaluation type, `y_pred` and `y` can only be one-hot encoding with
should be used when 'eval_type' is 'multilabel'. Shape can also be :math:`(N, 1)` if category values 0 or 1. Indices with 1 indicate positive category. The shape of `y_pred` and `y`
index is used in 'classification' evaluation type. are both :math:`(N, C)`.
Raises: Raises:

View File

@ -33,14 +33,13 @@ class TopKCategoricalAccuracy(Metric):
ValueError: If `k` is less than 1. ValueError: If `k` is less than 1.
Examples: Examples:
>>> x = mindspore.Tensor(np.array([[0.2, 0.5, 0.3, 0.6, 0.2], [0.1, 0.35, 0.5, 0.2, 0.], >>> x = Tensor(np.array([[0.2, 0.5, 0.3, 0.6, 0.2], [0.1, 0.35, 0.5, 0.2, 0.],
>>> [0.9, 0.6, 0.2, 0.01, 0.3]]), mindspore.float32) >>> [0.9, 0.6, 0.2, 0.01, 0.3]]), mindspore.float32)
>>> y = mindspore.Tensor(np.array([2, 0, 1]), mindspore.float32) >>> y = Tensor(np.array([2, 0, 1]), mindspore.float32)
>>> topk = nn.TopKCategoricalAccuracy(3) >>> topk = nn.TopKCategoricalAccuracy(3)
>>> topk.clear() >>> topk.clear()
>>> topk.update(x, y) >>> topk.update(x, y)
>>> result = topk.eval() >>> result = topk.eval()
0.6666666666666666
""" """
def __init__(self, k): def __init__(self, k):
super(TopKCategoricalAccuracy, self).__init__() super(TopKCategoricalAccuracy, self).__init__()
@ -65,7 +64,7 @@ class TopKCategoricalAccuracy(Metric):
y_pred is in most cases (not strictly) a list of floating numbers in range :math:`[0, 1]` y_pred is in most cases (not strictly) a list of floating numbers in range :math:`[0, 1]`
and the shape is :math:`(N, C)`, where :math:`N` is the number of cases and :math:`C` and the shape is :math:`(N, C)`, where :math:`N` is the number of cases and :math:`C`
is the number of categories. y contains values of integers. The shape is :math:`(N, C)` is the number of categories. y contains values of integers. The shape is :math:`(N, C)`
if one-hot encoding is used. Shape can also be :math:`(N, 1)` if category index is used. if one-hot encoding is used. Shape can also be :math:`(N,)` if category index is used.
""" """
if len(inputs) != 2: if len(inputs) != 2:
raise ValueError('Topk need 2 inputs (y_pred, y), but got {}'.format(len(inputs))) raise ValueError('Topk need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
@ -98,9 +97,9 @@ class Top1CategoricalAccuracy(TopKCategoricalAccuracy):
Refer to class 'TopKCategoricalAccuracy' for more details. Refer to class 'TopKCategoricalAccuracy' for more details.
Examples: Examples:
>>> x = mindspore.Tensor(np.array([[0.2, 0.5, 0.3, 0.6, 0.2], [0.1, 0.35, 0.5, 0.2, 0.], >>> x = Tensor(np.array([[0.2, 0.5, 0.3, 0.6, 0.2], [0.1, 0.35, 0.5, 0.2, 0.],
>>> [0.9, 0.6, 0.2, 0.01, 0.3]]), mindspore.float32) >>> [0.9, 0.6, 0.2, 0.01, 0.3]]), mindspore.float32)
>>> y = mindspore.Tensor(np.array([2, 0, 1]), mindspore.float32) >>> y = Tensor(np.array([2, 0, 1]), mindspore.float32)
>>> topk = nn.Top1CategoricalAccuracy() >>> topk = nn.Top1CategoricalAccuracy()
>>> topk.clear() >>> topk.clear()
>>> topk.update(x, y) >>> topk.update(x, y)
@ -116,9 +115,9 @@ class Top5CategoricalAccuracy(TopKCategoricalAccuracy):
Refer to class 'TopKCategoricalAccuracy' for more details. Refer to class 'TopKCategoricalAccuracy' for more details.
Examples: Examples:
>>> x = mindspore.Tensor(np.array([[0.2, 0.5, 0.3, 0.6, 0.2], [0.1, 0.35, 0.5, 0.2, 0.], >>> x = Tensor(np.array([[0.2, 0.5, 0.3, 0.6, 0.2], [0.1, 0.35, 0.5, 0.2, 0.],
>>> [0.9, 0.6, 0.2, 0.01, 0.3]]), mindspore.float32) >>> [0.9, 0.6, 0.2, 0.01, 0.3]]), mindspore.float32)
>>> y = mindspore.Tensor(np.array([2, 0, 1]), mindspore.float32) >>> y = Tensor(np.array([2, 0, 1]), mindspore.float32)
>>> topk = nn.Top5CategoricalAccuracy() >>> topk = nn.Top5CategoricalAccuracy()
>>> topk.clear() >>> topk.clear()
>>> topk.update(x, y) >>> topk.update(x, y)

View File

@ -161,7 +161,7 @@ class Adam(Optimizer):
Examples: Examples:
>>> net = Net() >>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits() >>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = Adam(params=net.trainable_params()) >>> optim = nn.Adam(params=net.trainable_params())
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
""" """
@ -252,7 +252,7 @@ class AdamWeightDecay(Optimizer):
Examples: Examples:
>>> net = Net() >>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits() >>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = AdamWeightDecay(params=net.trainable_params()) >>> optim = nn.AdamWeightDecay(params=net.trainable_params())
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
""" """
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0): def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
@ -306,7 +306,7 @@ class AdamWeightDecayDynamicLR(Optimizer):
Examples: Examples:
>>> net = Net() >>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits() >>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = AdamWeightDecayDynamicLR(params=net.trainable_params(), decay_steps=10) >>> optim = nn.AdamWeightDecayDynamicLR(params=net.trainable_params(), decay_steps=10)
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
""" """
def __init__(self, def __init__(self,

View File

@ -87,7 +87,7 @@ class FTRL(Optimizer):
Examples: Examples:
>>> net = Net() >>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits() >>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> opt = FTRL(net.trainable_params()) >>> opt = nn.FTRL(net.trainable_params())
>>> model = Model(net, loss_fn=loss, optimizer=opt, metrics=None) >>> model = Model(net, loss_fn=loss, optimizer=opt, metrics=None)
""" """
def __init__(self, params, initial_accum=0.1, learning_rate=0.001, lr_power=-0.5, l1=0.0, l2=0.0, def __init__(self, params, initial_accum=0.1, learning_rate=0.001, lr_power=-0.5, l1=0.0, l2=0.0,

View File

@ -163,7 +163,7 @@ class Lamb(Optimizer):
Examples: Examples:
>>> net = Net() >>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits() >>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = Lamb(params=net.trainable_params(), decay_steps=10) >>> optim = nn.Lamb(params=net.trainable_params(), decay_steps=10)
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
""" """

View File

@ -90,8 +90,8 @@ class LARS(Cell):
Examples: Examples:
>>> net = Net() >>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits() >>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> opt = Momentum(net.trainable_params(), 0.1, 0.9) >>> opt = nn.Momentum(net.trainable_params(), 0.1, 0.9)
>>> opt_lars = LARS(opt, epsilon=1e-08, hyperpara=0.02) >>> opt_lars = nn.LARS(opt, epsilon=1e-08, hyperpara=0.02)
>>> model = Model(net, loss_fn=loss, optimizer=opt_lars, metrics=None) >>> model = Model(net, loss_fn=loss, optimizer=opt_lars, metrics=None)
""" """

View File

@ -83,7 +83,7 @@ class Momentum(Optimizer):
Examples: Examples:
>>> net = Net() >>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits() >>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
""" """
def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0, def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0,

View File

@ -132,7 +132,7 @@ class RMSProp(Optimizer):
Examples: Examples:
>>> net = Net() >>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits() >>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> opt = RMSProp(params=net.trainable_params(), learning_rate=lr) >>> opt = nn.RMSProp(params=net.trainable_params(), learning_rate=lr)
>>> model = Model(net, loss, opt) >>> model = Model(net, loss, opt)
""" """
def __init__(self, params, learning_rate=0.1, decay=0.9, momentum=0.0, epsilon=1e-10, def __init__(self, params, learning_rate=0.1, decay=0.9, momentum=0.0, epsilon=1e-10,

View File

@ -77,7 +77,7 @@ class SGD(Optimizer):
Examples: Examples:
>>> net = Net() >>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits() >>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = SGD(params=net.trainable_params()) >>> optim = nn.SGD(params=net.trainable_params())
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
""" """
def __init__(self, params, learning_rate=0.1, momentum=0.0, dampening=0.0, weight_decay=0.0, nesterov=False, def __init__(self, params, learning_rate=0.1, momentum=0.0, dampening=0.0, weight_decay=0.0, nesterov=False,

View File

@ -50,8 +50,8 @@ class WithLossCell(Cell):
>>> net_with_criterion = nn.WithLossCell(net, loss_fn) >>> net_with_criterion = nn.WithLossCell(net, loss_fn)
>>> >>>
>>> batch_size = 2 >>> batch_size = 2
>>> data = mindspore.Tensor(np.ones([batch_size, 3, 64, 64]).astype(np.float32) * 0.01) >>> data = Tensor(np.ones([batch_size, 3, 64, 64]).astype(np.float32) * 0.01)
>>> label = mindspore.Tensor(np.ones([batch_size, 1, 1, 1]).astype(np.int32)) >>> label = Tensor(np.ones([batch_size, 1, 1, 1]).astype(np.int32))
>>> >>>
>>> net_with_criterion(data, label) >>> net_with_criterion(data, label)
""" """
@ -62,16 +62,6 @@ class WithLossCell(Cell):
self._loss_fn = loss_fn self._loss_fn = loss_fn
def construct(self, data, label): def construct(self, data, label):
"""
Computes loss based on the wrapped loss cell.
Args:
data (Tensor): Tensor data to train.
label (Tensor): Tensor label data.
Returns:
Tensor, compute result.
"""
out = self._backbone(data) out = self._backbone(data)
return self._loss_fn(out, label) return self._loss_fn(out, label)
@ -137,19 +127,6 @@ class WithGradCell(Cell):
self.network_with_loss.set_train() self.network_with_loss.set_train()
def construct(self, data, label): def construct(self, data, label):
"""
Computes gradients based on the wrapped gradients cell.
Note:
Run in PyNative mode.
Args:
data (Tensor): Tensor data to train.
label (Tensor): Tensor label data.
Returns:
Tensor, return compute gradients.
"""
weights = self.weights weights = self.weights
if self.sens is None: if self.sens is None:
grads = self.grad(self.network_with_loss, weights)(data, label) grads = self.grad(self.network_with_loss, weights)(data, label)
@ -355,7 +332,7 @@ class ParameterUpdate(Cell):
>>> param = network.parameters_dict()['learning_rate'] >>> param = network.parameters_dict()['learning_rate']
>>> update = nn.ParameterUpdate(param) >>> update = nn.ParameterUpdate(param)
>>> update.phase = "update_param" >>> update.phase = "update_param"
>>> lr = mindspore.Tensor(0.001, mindspore.float32) >>> lr = Tensor(0.001, mindspore.float32)
>>> update(lr) >>> update(lr)
""" """

View File

@ -120,25 +120,36 @@ class DistributedGradReducer(Cell):
ValueError: If degree is not a int or less than 0. ValueError: If degree is not a int or less than 0.
Examples: Examples:
>>> from mindspore.communication import get_group_size >>> from mindspore.communication import init, get_group_size
>>> from mindspore.ops import composite as C >>> from mindspore.ops import composite as C
>>> from mindspore.ops import operations as P >>> from mindspore.ops import operations as P
>>> from mindspore.ops import functional as F >>> from mindspore.ops import functional as F
>>> from mindspore import context >>> from mindspore import context
>>> from mindspore import nn
>>> from mindspore import ParallelMode, ParameterTuple
>>>
>>> device_id = int(os.environ["DEVICE_ID"])
>>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,
>>> device_id=int(device_id), enable_hccl=True)
>>> init()
>>> context.reset_auto_parallel_context()
>>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL)
>>>
>>> >>>
>>> class TrainingWrapper(nn.Cell): >>> class TrainingWrapper(nn.Cell):
>>> def __init__(self, network, optimizer, sens=1.0): >>> def __init__(self, network, optimizer, sens=1.0):
>>> super(TrainingWrapper, self).__init__(auto_prefix=False) >>> super(TrainingWrapper, self).__init__(auto_prefix=False)
>>> self.network = network >>> self.network = network
>>> self.weights = mindspore.ParameterTuple(network.trainable_params()) >>> self.network.add_flags(defer_inline=True)
>>> self.weights = ParameterTuple(network.trainable_params())
>>> self.optimizer = optimizer >>> self.optimizer = optimizer
>>> self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) >>> self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
>>> self.sens = sens >>> self.sens = sens
>>> self.reducer_flag = False >>> self.reducer_flag = False
>>> self.grad_reducer = None >>> self.grad_reducer = None
>>> self.parallel_mode = context.get_auto_parallel_context("parallel_mode") >>> self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
>>> if self.parallel_mode in [mindspore.ParallelMode.DATA_PARALLEL, >>> if self.parallel_mode in [ParallelMode.DATA_PARALLEL,
>>> mindspore.ParallelMode.HYBRID_PARALLEL]: >>> ParallelMode.HYBRID_PARALLEL]:
>>> self.reducer_flag = True >>> self.reducer_flag = True
>>> if self.reducer_flag: >>> if self.reducer_flag:
>>> mean = context.get_auto_parallel_context("mirror_mean") >>> mean = context.get_auto_parallel_context("mirror_mean")
@ -161,8 +172,8 @@ class DistributedGradReducer(Cell):
>>> network = Net() >>> network = Net()
>>> optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9) >>> optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> train_cell = TrainingWrapper(network, optimizer) >>> train_cell = TrainingWrapper(network, optimizer)
>>> inputs = mindspore.Tensor(np.ones([16, 16]).astype(np.float32)) >>> inputs = Tensor(np.ones([16, 16]).astype(np.float32))
>>> label = mindspore.Tensor(np.zeros([16, 16]).astype(np.float32)) >>> label = Tensor(np.zeros([16, 16]).astype(np.float32))
>>> grads = train_cell(inputs, label) >>> grads = train_cell(inputs, label)
""" """

View File

@ -65,9 +65,10 @@ class DynamicLossScaleUpdateCell(Cell):
>>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_update_cell=manager) >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_update_cell=manager)
>>> train_network.set_train() >>> train_network.set_train()
>>> >>>
>>> inputs = mindspore.Tensor(np.ones([16, 16]).astype(np.float32)) >>> inputs = Tensor(np.ones([16, 16]).astype(np.float32))
>>> label = mindspore.Tensor(np.zeros([16, 16]).astype(np.float32)) >>> label = Tensor(np.zeros([16, 16]).astype(np.float32))
>>> output = train_network(inputs, label) >>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mindspore.float32)
>>> output = train_network(inputs, label, scaling_sens)
""" """
def __init__(self, def __init__(self,
@ -126,13 +127,14 @@ class FixedLossScaleUpdateCell(Cell):
Examples: Examples:
>>> net_with_loss = Net() >>> net_with_loss = Net()
>>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9) >>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> manager = nn.FixedLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000) >>> manager = nn.FixedLossScaleUpdateCell(loss_scale_value=2**12)
>>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_update_cell=manager) >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_update_cell=manager)
>>> train_network.set_train() >>> train_network.set_train()
>>> >>>
>>> inputs = mindspore.Tensor(np.ones([16, 16]).astype(np.float32)) >>> inputs = Tensor(np.ones([16, 16]).astype(np.float32))
>>> label = mindspore.Tensor(np.zeros([16, 16]).astype(np.float32)) >>> label = Tensor(np.zeros([16, 16]).astype(np.float32))
>>> output = train_network(inputs, label) >>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mindspore.float32)
>>> output = train_network(inputs, label, scaling_sens)
""" """
def __init__(self, loss_scale_value): def __init__(self, loss_scale_value):
@ -181,9 +183,9 @@ class TrainOneStepWithLossScaleCell(Cell):
>>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_update_cell=manager) >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_update_cell=manager)
>>> train_network.set_train() >>> train_network.set_train()
>>> >>>
>>> inputs = mindspore.Tensor(np.ones([16, 16]).astype(np.float32)) >>> inputs = Tensor(np.ones([16, 16]).astype(np.float32))
>>> label = mindspore.Tensor(np.zeros([16, 16]).astype(np.float32)) >>> label = Tensor(np.zeros([16, 16]).astype(np.float32))
>>> scaling_sens = mindspore.Tensor(np.full((1), np.finfo(np.float32).max), dtype=mindspore.float32) >>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mindspore.float32)
>>> output = train_network(inputs, label, scaling_sens) >>> output = train_network(inputs, label, scaling_sens)
""" """