forked from mindspore-Ecosystem/mindspore
!48850 add set_train note and opt docs
Merge pull request !48850 from changzherui/code_docs_api
This commit is contained in:
commit
13b0afac51
|
@ -512,6 +512,10 @@
|
|||
|
||||
设置当前Cell和所有子Cell的训练模式。对于训练和预测具有不同结构的网络层(如 `BatchNorm`),将通过这个属性区分分支。如果设置为True,则执行训练分支,否则执行另一个分支。
|
||||
|
||||
.. note::
|
||||
当执行Model.train()的时候,框架会默认调用Cell.set_train(True)。
|
||||
当执行Model.eval()的时候,框架会默认调用Cell.set_train(False)。
|
||||
|
||||
参数:
|
||||
- **mode** (bool) - 指定模型是否为训练模式。默认值:True。
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@
|
|||
- **use_locking** (bool) - 是否对参数更新增加锁保护。默认值:False。
|
||||
|
||||
输入:
|
||||
- **var** (Tensor) - 要更新的权重。
|
||||
- **var** (Parameter) - 要更新的权重。
|
||||
- **mean_gradient** (Tensor) - 均值梯度,数据类型必须与 `var` 相同。
|
||||
- **mean_square** (Tensor) - 均方梯度,数据类型必须与 `var` 相同。
|
||||
- **moment** (Tensor) - `var` 的增量,数据类型必须与 `var` 相同。
|
||||
|
|
|
@ -28,7 +28,7 @@ mindspore.ops.ApplyRMSProp
|
|||
- **use_locking** (bool) - 是否对参数更新加锁保护。默认值: False。
|
||||
|
||||
输入:
|
||||
- **var** (Tensor) - 待更新的网络参数。
|
||||
- **var** (Parameter) - 待更新的网络参数。
|
||||
- **mean_square** (Tensor) - 均方梯度,数据类型需与 `var` 相同。
|
||||
- **moment** (Tensor) - 一阶矩,数据类型需与 `var` 相同。
|
||||
- **learning_rate** (Union[Number, Tensor]) - 学习率。需为浮点数或者数据类型为float16或float32的标量矩阵。
|
||||
|
|
|
@ -1604,6 +1604,10 @@ class Cell(Cell_):
|
|||
for training and predicting, such as `BatchNorm`, will distinguish between the branches by this attribute. If
|
||||
set to true, the training branch will be executed, otherwise another branch.
|
||||
|
||||
Note:
|
||||
When execute function Model.train(), framework will call Cell.set_train(True).
|
||||
When execute function Model.eval(), framework will call Cell.set_train(False).
|
||||
|
||||
Args:
|
||||
mode (bool): Specifies whether the model is training. Default: True.
|
||||
|
||||
|
|
|
@ -3313,7 +3313,7 @@ class ApplyRMSProp(PrimitiveWithInfer):
|
|||
from being updated. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **var** (Tensor) - Weights to be updated.
|
||||
- **var** (Parameter) - Weights to be updated.
|
||||
- **mean_square** (Tensor) - Mean square gradients, must be the same type as `var`.
|
||||
- **moment** (Tensor) - Delta of `var`, must be the same type as `var`.
|
||||
- **learning_rate** (Union[Number, Tensor]) - Learning rate. Must be a float number or
|
||||
|
@ -3407,7 +3407,7 @@ class ApplyCenteredRMSProp(Primitive):
|
|||
from being updated. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **var** (Tensor) - Weights to be updated.
|
||||
- **var** (Parameter) - Weights to be updated.
|
||||
- **mean_gradient** (Tensor) - Mean gradients, must be the same type as `var`.
|
||||
- **mean_square** (Tensor) - Mean square gradients, must be the same type as `var`.
|
||||
- **moment** (Tensor) - Delta of `var`, must be the same type as `var`.
|
||||
|
|
Loading…
Reference in New Issue