!18963 Front-end annotation correction

Merge pull request !18963 from wangnan39/fix_docs
This commit is contained in:
i-robot 2021-06-28 10:33:37 +00:00 committed by Gitee
commit 3550dc2918
10 changed files with 209 additions and 75 deletions

View File

@ -13,13 +13,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/getnext.h"
#include <set>
#include <string>
#include <vector>
#include <memory>
#include <algorithm>
#include "ops/getnext.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"

View File

@ -62,10 +62,12 @@ class Cell(Cell_):
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore.nn as nn
>>> import mindspore.ops as ops
>>> class MyCell(nn.Cell):
... def __init__(self):
... super(MyCell, self).__init__()
... self.relu = P.ReLU()
... self.relu = ops.ReLU()
...
... def construct(self, x):
... return self.relu(x)
@ -607,7 +609,7 @@ class Cell(Cell_):
Compiles cell.
Args:
inputs (tuple): Input parameters.
inputs (tuple): Inputs of the Cell object.
"""
_executor.compile(self, *inputs, phase=self.phase, auto_parallel_mode=self._auto_parallel_mode)
@ -616,7 +618,7 @@ class Cell(Cell_):
Compiles and runs cell.
Args:
inputs (tuple): Input parameters.
inputs (tuple): Inputs of the Cell object.
Returns:
Object, the result of executing.
@ -682,8 +684,13 @@ class Cell(Cell_):
"""
Cast parameter according to auto mix precision level in pynative mode.
This interface is currently used in the case of auto mix precision and usually need not to be used explicitly.
Args:
param (Parameter): The parameter to cast.
param (Parameter): Parameters, the type of which should be cast.
Returns:
Parameter, the input parameter with type automatically casted.
"""
if hasattr(self, "_mindspore_flags"):
if self._mindspore_flags.get('fp32'):
@ -725,7 +732,11 @@ class Cell(Cell_):
return None
def remove_redundant_parameters(self):
"""Remove the redundant parameters"""
"""
Remove the redundant parameters.
This interface usually need not to be used explicitly.
"""
cells = self.cells_and_names()
for _, cell in cells:
params = cell._params.items()
@ -836,7 +847,7 @@ class Cell(Cell_):
Adds the given prefix to the names of parameters.
Args:
prefix (str): The prefix string.
prefix (str): The prefix string. Default: ''.
recurse (bool): Whether contains the parameters of subcells. Default: True.
"""
@ -884,6 +895,9 @@ class Cell(Cell_):
expand (bool): If true, yields parameters of this cell and all subcells. Otherwise, only yield parameters
that are direct members of this cell. Default: True.
Returns:
Iteration, all parameters at the Cell.
Examples:
>>> net = Net()
>>> parameters = []
@ -905,13 +919,16 @@ class Cell(Cell_):
"""
Returns an iterator over cell parameters.
Includes the parameter's name and itself.
Includes the parameter's name and itself.
Args:
name_prefix (str): Namespace. Default: ''.
expand (bool): If true, yields parameters of this cell and all subcells. Otherwise, only yield parameters
that are direct members of this cell. Default: True.
Returns:
Iteration, all the names and corresponding parameters in the cell.
Examples:
>>> n = Net()
>>> names = []
@ -949,6 +966,9 @@ class Cell(Cell_):
cells (str): Cells to iterate over. Default: None.
name_prefix (str): Namespace. Default: ''.
Returns:
Iteration, all the child cells and corresponding names in the cell.
Examples:
>>> n = Net()
>>> names = []
@ -972,7 +992,12 @@ class Cell(Cell_):
yield ele
def cells(self):
"""Returns an iterator over immediate cells."""
"""
Returns an iterator over immediate cells.
Returns:
Iteration, all the child cells in the cell.
"""
return self.name_cells().values()
def _set_scope(self, name):
@ -997,7 +1022,12 @@ class Cell(Cell_):
yield key, value
def get_scope(self):
"""Returns the scope of a cell object in one network."""
"""
Returns the scope of a cell object in one network.
Returns:
String, scope of the cell.
"""
return self._scope
def generate_scope(self):
@ -1010,6 +1040,9 @@ class Cell(Cell_):
Returns an iterator over all cells in the network.
Include name of the cell and cell itself.
Returns:
Dict[String, Cell], all the child cells and corresponding names in the cell.
"""
value_set = set()
cells = OrderedDict()
@ -1056,6 +1089,9 @@ class Cell(Cell_):
dst_type (:class:`mindspore.dtype`): Transfer Cell to Run with dst_type.
dst_type can be `mindspore.dtype.float16` or `mindspore.dtype.float32`.
Returns:
Cell, the cell itself.
Raises:
ValueError: If dst_type is not float32 nor float16.
"""
@ -1080,6 +1116,9 @@ class Cell(Cell_):
Args:
acc_type (str): accelerate algorithm.
Returns:
Cell, the cell itself.
Raises:
ValueError: If acc_type is not in the algorithm library.
"""
@ -1098,6 +1137,9 @@ class Cell(Cell_):
Args:
requires_grad (bool): Specifies if the net need to grad, if it is
True, cell will construct backward network in pynative mode. Default: True.
Returns:
Cell, the cell itself.
"""
self.requires_grad = requires_grad
return self
@ -1112,6 +1154,9 @@ class Cell(Cell_):
Args:
mode (bool): Specifies whether the model is training. Default: True.
Returns:
Cell, the cell itself.
"""
if mode is False:
self._phase = 'predict'

View File

@ -37,7 +37,10 @@ class LearningRateSchedule(Cell):
The output must be a Tensor of scalar.
Inputs:
Tensor. The current step number.
- **global_step** (Tensor) - The current step number.
Inputs:
Tensor. Learning rate at current step with shape :math:`()`.
"""
raise NotImplementedError
@ -77,10 +80,10 @@ class ExponentialDecayLR(LearningRateSchedule):
is_stair (bool): If true, learning rate is decayed once every `decay_steps` time. Default: False.
Inputs:
Tensor. The current step number.
- **global_step** (Tensor) - The current step number.
Outputs:
Tensor. The learning rate value for the current step.
Tensor. The learning rate value for the current step with shape :math:`()`.
Raises:
TypeError: If `learning_rate` or `decay_rate` is not a float.
@ -144,10 +147,10 @@ class NaturalExpDecayLR(LearningRateSchedule):
is_stair (bool): If true, learning rate is decayed once every `decay_steps` time. Default: False.
Inputs:
Tensor. The current step number.
- **global_step** (Tensor) - The current step number.
Outputs:
Tensor. The learning rate value for the current step.
Tensor. The learning rate value for the current step with shape :math:`()`.
Raises:
TypeError: If `learning_rate` or `decay_rate` is not a float.
@ -212,10 +215,10 @@ class InverseDecayLR(LearningRateSchedule):
is_stair (bool): If true, learning rate decay once every `decay_steps` times. Default: False.
Inputs:
Tensor. The current step number.
- **global_step** (Tensor) - The current step number.
Outputs:
Tensor. The learning rate value for the current step.
Tensor. The learning rate value for the current step with shape :math:`()`.
Raises:
TypeError: If `learning_rate` or `decay_rate` is not a float.
@ -269,10 +272,10 @@ class CosineDecayLR(LearningRateSchedule):
decay_steps (int): A value used to calculate decayed learning rate.
Inputs:
Tensor. The current step number.
- **global_step** (Tensor) - The current step number.
Outputs:
Tensor. The learning rate value for the current step.
Tensor. The learning rate value for the current step with shape :math:`()`.
Raises:
TypeError: If `min_lr` or `max_lr` is not a float.
@ -345,10 +348,10 @@ class PolynomialDecayLR(LearningRateSchedule):
update_decay_steps (bool): If true, learning rate is decayed once every `decay_steps` time. Default: False.
Inputs:
Tensor. The current step number.
- **global_step** (Tensor) - The current step number.
Outputs:
Tensor. The learning rate value for the current step.
Tensor. The learning rate value for the current step with shape :math:`()`.
Raises:
TypeError: If `learning_rate`, `end_learning_rate` or `power` is not a float.
@ -424,10 +427,10 @@ class WarmUpLR(LearningRateSchedule):
warmup_steps (int): The warm up steps of learning rate.
Inputs:
Tensor. The current step number.
- **global_step** (Tensor) - The current step number.
Outputs:
Tensor. The learning rate value for the current step.
Tensor. The learning rate value for the current step with shape :math:`()`.
Raises:
TypeError: If `learning_rate` is not a float.

View File

@ -195,17 +195,16 @@ class Adam(Optimizer):
.. math::
\begin{array}{ll} \\
m = \beta_1 * m + (1 - \beta_1) * g \\
v = \beta_2 * v + (1 - \beta_2) * g * g \\
m_{t+1} = \beta_1 * m_{t} + (1 - \beta_1) * g \\
v_{t+1} = \beta_2 * v_{t} + (1 - \beta_2) * g * g \\
l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\
w = w - l * \frac{m}{\sqrt{v} + \epsilon}
w_{t+1} = w_{t} - l * \frac{m_{t+1}}{\sqrt{v_{t+1}} + \eps}
\end{array}
:math:`m` represents the 1st moment vector `moment1`, :math:`v` represents the 2nd moment vector `moment2`,
:math:`g` represents `gradients`, :math:`l` represents scaling factor `lr`, :math:`\beta_1, \beta_2` represent
:math:`g` represents `gradients`, :math:`l` represents scaling factor, :math:`\beta_1, \beta_2` represent
`beta1` and `beta2`, :math:`t` represents updating step while :math:`beta_1^t` and :math:`beta_2^t` represent
`beta1_power` and `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `params`,
:math:`\epsilon` represents `eps`.
`beta1_power` and `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `params`.
Note:
When separating parameter groups, the weight decay in each group will be applied on the parameters if the
@ -371,9 +370,29 @@ class Adam(Optimizer):
class AdamWeightDecay(Optimizer):
"""
r"""
Implements the Adam algorithm to fix the weight decay.
.. math::
\begin{array}{ll} \\
m_{t+1} = \beta_1 * m_{t} + (1 - \beta_1) * g \\
v_{t+1} = \beta_2 * v_{t} + (1 - \beta_2) * g * g \\
update = \frac{m_{t+1}}{\sqrt{v_{t+1}} + eps} \\
update =
\begin{cases}
update + \weight\_decay * w_{t}
& \text{ if } \weight\_decay > 0 \\
\update
& \text{ otherwise }
\end{cases} \\
w_{t+1} = w_{t} - lr * update
\end{array}
:math:`m` represents the 1st moment vector `moment1`, :math:`v` represents the 2nd moment vector `moment2`,
:math:`g` represents `gradients`, :math:`lr` represents `learning_rate`,
:math:`\beta_1, \beta_2` represent `beta1` and `beta2`, :math:`t` represents updating step while
:math:`w` represents `params`.
Note:
When separating parameter groups, the weight decay in each group will be applied on the parameters if the
weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
@ -493,17 +512,16 @@ class AdamOffload(Optimizer):
.. math::
\begin{array}{ll} \\
m = \beta_1 * m + (1 - \beta_1) * g \\
v = \beta_2 * v + (1 - \beta_2) * g * g \\
m_{t+1} = \beta_1 * m_{t} + (1 - \beta_1) * g \\
v_{t+1} = \beta_2 * v_{t} + (1 - \beta_2) * g * g \\
l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\
w = w - l * \frac{m}{\sqrt{v} + \epsilon}
w_{t+1} = w_{t} - l * \frac{m_{t+1}}{\sqrt{v_{t+1}} + \eps}
\end{array}
:math:`m` represents the 1st moment vector `moment1`, :math:`v` represents the 2nd moment vector `moment2`,
:math:`g` represents `gradients`, :math:`l` represents scaling factor `lr`, :math:`\beta_1, \beta_2` represent
:math:`g` represents `gradients`, :math:`l` represents scaling factor, :math:`\beta_1, \beta_2` represent
`beta1` and `beta2`, :math:`t` represents updating step while :math:`beta_1^t` and :math:`beta_2^t` represent
`beta1_power` and `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `params`,
:math:`\epsilon` represents `eps`.
`beta1_power` and `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `params`.
Note:
This optimizer only supports `GRAPH_MODE` currently.

View File

@ -114,17 +114,16 @@ class LazyAdam(Optimizer):
.. math::
\begin{array}{ll} \\
m = \beta_1 * m + (1 - \beta_1) * g \\
v = \beta_2 * v + (1 - \beta_2) * g * g \\
m_{t+1} = \beta_1 * m_{t} + (1 - \beta_1) * g \\
v_{t+1} = \beta_2 * v_{t} + (1 - \beta_2) * g * g \\
l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\
w = w - l * \frac{m}{\sqrt{v} + \epsilon}
w_{t+1} = w_{t} - l * \frac{m_{t+1}}{\sqrt{v_{t+1}} + \eps}
\end{array}
:math:`m` represents the 1st moment vector `moment1`, :math:`v` represents the 2nd moment vector `moment2`,
:math:`g` represents `gradients`, :math:`l` represents scaling factor `lr`, :math:`\beta_1, \beta_2` represent
:math:`g` represents `gradients`, :math:`l` represents scaling factor, :math:`\beta_1, \beta_2` represent
`beta1` and `beta2`, :math:`t` represents updating step while :math:`beta_1^t` and :math:`beta_2^t` represent
`beta1_power` and `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `params`,
:math:`\epsilon` represents `eps`.
`beta1_power` and `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `params`.
Note:
When separating parameter groups, the weight decay in each group will be applied on the parameters if the

View File

@ -52,13 +52,25 @@ def _check_param_value(accum, l1, l2, use_locking, prim_name=None):
class ProximalAdagrad(Optimizer):
"""
r"""
Implements the ProximalAdagrad algorithm with ApplyProximalAdagrad Operator.
ProximalAdagrad is an online Learning and Stochastic Optimization.
Refer to paper `Efficient Learning using Forward-Backward Splitting
<http://papers.nips.cc//paper/3793-efficient-learning-using-forward-backward-splitting.pdf>`_.
.. math::
accum_{t+1} = accum_{t} + grad * grad
.. math::
\text{prox_v} = var_{t} - lr * grad * \frac{1}{\sqrt{accum_{t+1}}}
.. math::
var_{t+1} = \frac{sign(\text{prox_v})}{1 + lr * l2} * \max(\left| \text{prox_v} \right| - lr * l1, 0)
Here : where grad, lr, var, accum and t denote the gradients, learning_rate, params and accumulation and current
step respectively.
Note:
When separating parameter groups, the weight decay in each group will be applied on the parameters if the
weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied

View File

@ -66,6 +66,7 @@ def _tensors_cast_datatype(datatype, param):
return F.cast(param, datatype)
class WithLossCell(Cell):
r"""
Cell with loss function.
@ -82,7 +83,7 @@ class WithLossCell(Cell):
- **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
Outputs:
Tensor, a scalar tensor with shape :math:`()`.
Tensor, a tensor means the loss value, the shape of which is usually :math:`()`.
Raises:
TypeError: If dtype of `data` or `label` is neither float16 nor float32.
@ -114,7 +115,7 @@ class WithLossCell(Cell):
@property
def backbone_network(self):
"""
Returns the backbone network.
Get the backbone network.
Returns:
Cell, the backbone network.
@ -298,7 +299,7 @@ class TrainOneStepCell(Cell):
- **(\*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
Outputs:
Tensor, a scalar Tensor with shape :math:`()`.
Tensor, a tensor means the loss value, the shape of which is usually :math:`()`.
Raises:
TypeError: If `sens` is not a number.
@ -408,6 +409,12 @@ class GetNextSingleOp(Cell):
For detailed information, refer to `ops.operations.GetNext`.
Inputs:
No inputs.
Outputs:
tuple[Tensor], the data get from Dataset.
Supported Platforms:
``Ascend`` ``GPU``
@ -635,13 +642,19 @@ class WithEvalCell(Cell):
class ParameterUpdate(Cell):
"""
Cell that updates parameters.
Cell that updates parameter.
With this Cell, one can manually update `param` with the input `Tensor`.
Args:
param (Parameter): The parameter to be updated manually.
Inputs:
- **x** (Tensor) - A tensor whose shape and type are the same with `param`.
Outputs:
Tensor, the input `x`.
Raises:
KeyError: If parameter with the specified name does not exist.

View File

@ -72,11 +72,11 @@ class DynamicLossScaleUpdateCell(Cell):
scale_window (int): Maximum continuous training steps that do not have overflow.
Inputs:
- **inputs** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
- **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
- **loss_scale** (Tensor) - The loss scale value during training with shape :math:`()`.
- **overflow** (bool) - Whether the overflow occurs or not.
Outputs:
Tensor, a scalar Tensor with shape :math:`()`.
bool, the input `overflow`.
Raises:
TypeError: If dtype of `inputs` or `label` is neither float16 nor float32.
@ -165,6 +165,13 @@ class FixedLossScaleUpdateCell(Cell):
Args:
loss_scale_value (float): Initializes loss scale.
Inputs:
- **loss_scale** (Tensor) - The loss scale value during training with shape :math:`()`, that will be ignored.
- **overflow** (bool) - Whether the overflow occurs or not.
Outputs:
bool, the input `overflow`.
Supported Platforms:
``Ascend`` ``GPU``
@ -332,7 +339,11 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
def set_sense_scale(self, sens):
"""
If the user has set the sens in the training process and wants to reassign the value, he can call
this function again to make modification, and sens needs to be of type Tensor."""
this function again to make modification, and sens needs to be of type Tensor.
Inputs:
- **sens**(Tensor) - The new sense whose shape and type are the same with original `scale_sense`.
"""
if self.scale_sense and isinstance(sens, Tensor):
self.scale_sense.set_data(sens)
else:
@ -347,15 +358,15 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
overflow in the process of gradient calculation. In this case, pre_cond should be the output of the loss
function, and compute_input should be the input of gradients-computing function.
Args:
pre_cond(object): A precondition for starting overflow detection. It determines the executing order of
overflow state clearing and prior processions. It makes sure that the function 'start_overflow' clears
status after finishing the process of precondition.
compute_input(object): The input of subsequent process. Overflow detection should be performed on a certain
computation. Set `compute_input` as the input of the computation, to ensure overflow status is cleared
before executing the computation.
Inputs:
- **pre_cond** (Tensor) - A precondition for starting overflow detection. It determines the executing order
of overflow state clearing and prior processions. It makes sure that the function 'start_overflow'
clears status after finishing the process of precondition.
- **compute_input** (object) - The input of subsequent process. Overflow detection should be performed on a
certain computation. Set `compute_input` as the input of the computation, to ensure overflow status is
cleared before executing the computation.
Returns:
Outputs:
Tuple[object, object], the first value is False for GPU backend, while it is a instance of
NPUAllocFloatStatus for other backend. The status is used to detect overflow during overflow detection.
The second value is the same as the input of `compute_input`, but contains some information about the
@ -377,12 +388,13 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
Get overflow results after executing the target process for overflow detection.
Args:
status(object): A status instance used to detect the overflow.
compute_output: Overflow detection should be performed on a certain computation. Set `compute_output` as
the output of the computation, to ensure overflow status is acquired before executing the computation.
Inputs:
- **status** (object) - A status instance used to detect the overflow.
- **compute_output** - Overflow detection should be performed on a certain computation. Set `compute_output`
as the output of the computation, to ensure overflow status is acquired before executing the
computation.
Returns:
Outputs:
bool, whether the overflow occurs or not.
"""
if not self.gpu_target:
@ -409,10 +421,10 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
"""
Calculate loss scale according to the overflow.
Args:
overflow(bool): Whether the overflow occurs or not.
Inputs:
- **overflow** (bool) - Whether the overflow occurs or not.
Returns:
Outputs:
bool, overflow value.
"""
if self.loss_scaling_manager is not None:

View File

@ -266,7 +266,7 @@ class DatasetHelper:
self.iter.release()
def continue_send(self):
"""continue send data to device at the beginning of epoch."""
"""Continue send data to device at the beginning of epoch."""
self.iter.continue_send()
def get_data_info(self):

View File

@ -67,23 +67,39 @@ class FixedLossScaleManager(LossScaleManager):
self._drop_overflow_update = drop_overflow_update
def get_loss_scale(self):
"""Get loss scale value."""
"""
Get loss scale value.
Returns:
bool, `loss_scale` value.
"""
return self._loss_scale
def get_drop_overflow_update(self):
"""Get the flag whether to drop optimizer update when there is an overflow."""
"""
Get the flag whether to drop optimizer update when there is an overflow.
Returns:
bool, `drop_overflow_update` value.
"""
return self._drop_overflow_update
def update_loss_scale(self, overflow):
"""
Update loss scale value.
Update loss scale value. The interface at `FixedLossScaleManager` will do nothing.
Args:
overflow (bool): Whether it overflows.
"""
def get_update_cell(self):
"Returns the cell for `TrainOneStepWithLossScaleCell`"
"""
Returns the update cell for `TrainOneStepWithLossScaleCell`.
Returns:
None or Cell. Cell object, used to update `loss_scale`, when `drop_overflow_update` is True. None when
`drop_overflow_update` is False.
"""
if not self._drop_overflow_update:
return None
return nn.FixedLossScaleUpdateCell(self._loss_scale)
@ -127,7 +143,12 @@ class DynamicLossScaleManager(LossScaleManager):
self.bad_step = 0
def get_loss_scale(self):
"""Get loss scale value."""
"""
Get loss scale value.
Returns:
bool, `loss_scale` value.
"""
return self.loss_scale
def update_loss_scale(self, overflow):
@ -152,9 +173,19 @@ class DynamicLossScaleManager(LossScaleManager):
self.cur_iter += 1
def get_drop_overflow_update(self):
"""Get the flag whether to drop optimizer update when there is an overflow."""
"""
Get the flag whether to drop optimizer update when there is an overflow.
Returns:
bool, always return True at `DynamicLossScaleManager`.
"""
return True
def get_update_cell(self):
"Returns the cell for `TrainOneStepWithLossScaleCell`"
"""
Returns the update cell for `TrainOneStepWithLossScaleCell`.
Returns:
Cell, cell object used to update `loss_scale`.
"""
return nn.DynamicLossScaleUpdateCell(self.loss_scale, self.scale_factor, self.scale_window)