1. Add a note referring to `nn.SGD` for details
2. Delete the default value of `stat`
3. Delete examples
4. Fix some comment errors from wangting's review
5. Modify comments from jinyaohui's review
6. Modify examples from wanghao's review
7. Modify the Select operation examples
万万没想到 2020-03-23 15:33:01 +08:00
parent 0830ad932b
commit 605d980305
10 changed files with 27 additions and 26 deletions

View File

@ -40,7 +40,7 @@ class Softmax(Cell):
where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor.
Args:
axis (Union[int, tuple[int]]): The axis to apply Softmax operation. Default: -1, means the last dimension.
axis (Union[int, tuple[int]]): The axis to apply Softmax operation, -1 means the last dimension. Default: -1.
Inputs:
- **x** (Tensor) - The input of Softmax.
@ -70,7 +70,7 @@ class LogSoftmax(Cell):
where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor.
Args:
axis (int): The axis to apply LogSoftmax operation. Default: -1, means the last dimension.
axis (int): The axis to apply LogSoftmax operation, -1 means the last dimension. Default: -1.
Inputs:
- **x** (Tensor) - The input of LogSoftmax.
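For reference, a minimal usage sketch of the two activations touched above, assuming the standard `mindspore.nn` import path and eager-style execution; the tensor values are made up for illustration:
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> x = Tensor(np.array([[1.0, 2.0, 3.0]], dtype=np.float32))
>>> softmax = nn.Softmax()                # axis defaults to -1, the last dimension
>>> softmax(x)                            # each row sums to 1 along the last axis
>>> log_softmax = nn.LogSoftmax(axis=-1)
>>> log_softmax(x)                        # elementwise log of the corresponding softmax output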

View File

@ -32,13 +32,13 @@ class Dropout(Cell):
r"""
Dropout layer for the input.
Randomly set some elements of the input tensor to zero with probability :math:`1 - keep_prob` during training
Randomly set some elements of the input tensor to zero with probability :math:`1 - keep\_prob` during training
using samples from a Bernoulli distribution.
Note:
Each channel will be zeroed out independently on every construct call.
The outputs are scaled by a factor of :math:`\frac{1}{keep_prob}` during training so
The outputs are scaled by a factor of :math:`\frac{1}{keep\_prob}` during training so
that the output layer remains at a similar scale. During inference, this
layer returns the same tensor as the input.
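A small sketch of the scaling behaviour described above, assuming `nn.Dropout` takes a `keep_prob` keyword as in this file and that the cell is switched into training mode with `set_train()`:
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> dropout = nn.Dropout(keep_prob=0.8)
>>> dropout.set_train()                   # dropout only takes effect in training mode
>>> x = Tensor(np.ones([2, 4], dtype=np.float32))
>>> dropout(x)                            # kept elements are scaled by 1 / keep_prob = 1.25
>>> dropout.set_train(False)
>>> dropout(x)                            # inference: the input is returned unchanged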

View File

@ -241,7 +241,7 @@ class Conv2dTranspose(_Conv):
in_channels (int): The number of channels in the input space.
out_channels (int): The number of channels in the output space.
kernel_size (Union[int, tuple]): int or tuple with 2 integers, which specifies the height
and width of the 2D convolution window.Single int means the value if for both height and width of
and width of the 2D convolution window. Single int means the value is for both height and width of
the kernel. A tuple of 2 ints means the first value is for the height and the other is for the
width of the kernel.
stride (int): Specifies the same value for all spatial dimensions. Default: 1.
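To illustrate the `kernel_size` wording fixed above, a hedged sketch; the channel sizes and input shape are arbitrary:
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> # a single int is used for both the height and width of the kernel
>>> net_square = nn.Conv2dTranspose(in_channels=3, out_channels=16, kernel_size=3, stride=2)
>>> # a tuple of 2 ints sets the height and the width separately
>>> net_rect = nn.Conv2dTranspose(in_channels=3, out_channels=16, kernel_size=(3, 5), stride=2)
>>> x = Tensor(np.ones([1, 3, 16, 16], dtype=np.float32))
>>> net_square(x)                         # upsampled feature map with 16 output channels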

View File

@ -26,8 +26,8 @@ class Fbeta(Metric):
Fbeta score is a weighted mean of precision and recall.
.. math::
F_\beta=\frac{(1+\beta^2) \cdot true positive}
{(1+\beta^2) \cdot true positive +\beta^2 \cdot false negative + false positive}
F_\beta=\frac{(1+\beta^2) \cdot true\_positive}
{(1+\beta^2) \cdot true\_positive +\beta^2 \cdot false\_negative + false\_positive}
Args:
beta (float): The weight of precision.
@ -123,7 +123,7 @@ class F1(Fbeta):
Refer to class `Fbeta` for more details.
.. math::
F_\beta=\frac{2\cdot true positive}{2\cdot true positive + false negative + false positive}
F_\beta=\frac{2\cdot true\_positive}{2\cdot true\_positive + false\_negative + false\_positive}
Examples:
>>> x = mindspore.Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
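The reformatted formulas above can be checked with a plain-Python sketch; the counts below are made up for illustration, and this is not the `Fbeta` Metric API itself:
>>> def fbeta_score(tp, fp, fn, beta):
>>>     # direct evaluation of the F-beta definition above
>>>     return ((1 + beta ** 2) * tp) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp)
>>> fbeta_score(tp=8, fp=2, fn=4, beta=1.0)   # beta = 1 reduces to F1: 16 / 22 ≈ 0.727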

View File

@ -881,7 +881,7 @@ class ScalarToTensor(PrimitiveWithInfer):
Inputs:
- **input_x** (Union[int, float]) - The input is a scalar. Only constant value is allowed.
- **dtype** (mindspore.dtype) - The target data type. Default: mindspore.float32. Only
constant value is allowed.
constant value is allowed.
Outputs:
Tensor. 0-D Tensor and the content is the input.
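A minimal sketch of the operation documented above, assuming it is exported under `mindspore.ops.operations` as `P.ScalarToTensor`:
>>> import mindspore
>>> from mindspore.ops import operations as P
>>> scalar_to_tensor = P.ScalarToTensor()
>>> scalar_to_tensor(2, mindspore.float32)    # 0-D Tensor holding the value 2.0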
@ -1458,7 +1458,10 @@ class Select(PrimitiveWithInfer):
Examples:
>>> select = Select()
>>> select([True, False],[2,3],[1,2])
>>> input_x = Tensor([True, False])
>>> input_y = Tensor([2,3], mindspore.float32)
>>> input_z = Tensor([1,2], mindspore.float32)
>>> select(input_x, input_y, input_z)
"""
@prim_attr_register
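The new example above still assumes `Select` and its inputs are already in scope; a fuller, hedged sketch with the imports spelled out (the values mirror the docstring):
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor
>>> from mindspore.ops import operations as P
>>> select = P.Select()
>>> cond = Tensor(np.array([True, False]))
>>> x = Tensor(np.array([2.0, 3.0]), mindspore.float32)
>>> y = Tensor(np.array([1.0, 2.0]), mindspore.float32)
>>> select(cond, x, y)                        # picks x where cond is True, y elsewhere -> [2.0, 2.0]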

View File

@ -66,11 +66,12 @@ class AllReduce(PrimitiveWithInfer):
Examples:
>>> from mindspore.communication.management import init
>>> import mindspore.ops.operations as P
>>> init('nccl')
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.allreduce_sum = AllReduce(ReduceOp.SUM, group="nccl_world_group")
>>> self.allreduce_sum = P.AllReduce(ReduceOp.SUM, group="nccl_world_group")
>>>
>>> def construct(self, x):
>>> return self.allreduce_sum(x)
@ -130,11 +131,12 @@ class AllGather(PrimitiveWithInfer):
Examples:
>>> from mindspore.communication.management import init
>>> import mindspore.ops.operations as P
>>> init('nccl')
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.allgather = AllGather(group="nccl_world_group")
>>> self.allgather = P.AllGather(group="nccl_world_group")
>>>
>>> def construct(self, x):
>>> return self.allgather(x)
@ -184,11 +186,12 @@ class ReduceScatter(PrimitiveWithInfer):
Examples:
>>> from mindspore.communication.management import init
>>> import mindspore.ops.operations as P
>>> init('nccl')
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.reducescatter = ReduceScatter(ReduceOp.SUM, group="nccl_world_group")
>>> self.reducescatter = P.ReduceScatter(ReduceOp.SUM, group="nccl_world_group")
>>>
>>> def construct(self, x):
>>> return self.reducescatter(x)
@ -246,11 +249,12 @@ class Broadcast(PrimitiveWithInfer):
Examples:
>>> from mindspore.communication.management import init
>>> import mindspore.ops.operations as P
>>> init('nccl')
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.broadcast = Broadcast(1)
>>> self.broadcast = P.Broadcast(1)
>>>
>>> def construct(self, x):
>>> return self.broadcast((x,))
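The four examples above all follow the same pattern; a consolidated, hedged sketch for `AllReduce` (it assumes a multi-GPU NCCL environment, a distributed launcher, and that `ReduceOp` is importable from the comm_ops module):
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> from mindspore.communication.management import init
>>> import mindspore.ops.operations as P
>>> from mindspore.ops.operations.comm_ops import ReduceOp
>>> init('nccl')                              # requires an NCCL-capable multi-device setup
>>> class AllReduceNet(nn.Cell):
>>>     def __init__(self):
>>>         super(AllReduceNet, self).__init__()
>>>         self.allreduce_sum = P.AllReduce(ReduceOp.SUM, group="nccl_world_group")
>>>
>>>     def construct(self, x):
>>>         return self.allreduce_sum(x)
>>>
>>> net = AllReduceNet()
>>> net(Tensor(np.ones([2, 8], dtype=np.float32)))    # every element becomes the sum across devices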

View File

@ -150,7 +150,6 @@ class Merge(PrimitiveWithInfer):
raise NotImplementedError
def infer_shape(self, inputs):
"""merge select one input as its output"""
return (inputs[0], [1])
def infer_dtype(self, inputs):

View File

@ -1263,7 +1263,6 @@ class EqualCount(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
def infer_shape(self, x_shape, w_shape):
"""Infer shape."""
output_shape = (1,)
return output_shape

View File

@ -1310,6 +1310,9 @@ class SGD(PrimitiveWithInfer):
Nesterov momentum is based on the formula from On the importance of
initialization and momentum in deep learning.
Note:
For details, please refer to `nn.SGD` source code.
Args:
dampening (float): The dampening for momentum. Default: 0.0.
weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
@ -1321,16 +1324,10 @@ class SGD(PrimitiveWithInfer):
- **learning_rate** (Tensor) - Learning rate. e.g. Tensor(0.1, mindspore.float32).
- **accum** (Tensor) - Accum (velocity) to be updated.
- **momentum** (Tensor) - Momentum. e.g. Tensor(0.1, mindspore.float32).
- **stat** (Tensor) - States to be updated with the same shape as gradient. Default: 1.0.
- **stat** (Tensor) - States to be updated with the same shape as gradient.
Outputs:
Tensor, parameters to be updated.
Examples:
>>> net = ResNet50()
>>> loss = SoftmaxCrossEntropyWithLogits()
>>> opt = SGD(params=net.trainable_params(), learning_rate=lr, momentum=0.9)
>>> model = Model(net, loss, opt)
"""
@prim_attr_register
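Since the docstring now points to `nn.SGD` instead of carrying its own example, here is a hedged sketch of that higher-level interface; the network is a stand-in `nn.Dense`, and the `Model` import path is assumed:
>>> import mindspore.nn as nn
>>> from mindspore.train import Model
>>> net = nn.Dense(16, 10)                    # stand-in network for illustration
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> opt = nn.SGD(params=net.trainable_params(), learning_rate=0.01, momentum=0.9)
>>> model = Model(net, loss, opt)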
@ -1768,7 +1765,7 @@ class LSTM(PrimitiveWithInfer):
"""
Performs the long short term memory (LSTM) on the input.
Detailed information, please refer to `nn.layer.LSTM`.
Detailed information, please refer to `nn.LSTM`.
"""
@prim_attr_register

View File

@ -91,13 +91,12 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
loss_fn (Union[None, Cell]): Definition of the loss_fn. If None, the `network` should have the loss inside.
Default: None.
optimizer (Optimizer): Optimizer to update the Parameter.
level (str): Supports [O0, O2].
level (str): Supports [O0, O2]. Default: "O0".
- O0: Do not change.
- O2: Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32,
using dynamic loss scale.
Default: "O0"
cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32`.
If set to `mstype.float16`, use `float16` mode to train. If set, overwrite the level setting.
keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting.
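A hedged sketch of the mixed-precision setup described above, assuming `build_train_network` is importable from `mindspore.train.amp` and using a stand-in network and optimizer:
>>> import mindspore.nn as nn
>>> from mindspore.train.amp import build_train_network
>>> net = nn.Dense(16, 10)                    # stand-in network for illustration
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
>>> train_net = build_train_network(net, opt, loss, level="O2")   # float16 net, fp32 batchnorm and loss, dynamic loss scale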