!48850 add set_train note and opt docs

Merge pull request !48850 from changzherui/code_docs_api
This commit is contained in:
i-robot 2023-02-14 02:51:55 +00:00 committed by Gitee
commit 13b0afac51
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 12 additions and 4 deletions

View File

@ -512,6 +512,10 @@
设置当前Cell和所有子Cell的训练模式。对于训练和预测具有不同结构的网络层(如 `BatchNorm`)将通过这个属性区分分支。如果设置为True则执行训练分支否则执行另一个分支。
.. note::
当执行Model.train()的时候框架会默认调用Cell.set_train(True)。
当执行Model.eval()的时候框架会默认调用Cell.set_train(False)。
参数:
- **mode** (bool) - 指定模型是否为训练模式。默认值True。

View File

@ -35,7 +35,7 @@
- **use_locking** (bool) - 是否对参数更新增加锁保护。默认值False。
输入:
- **var** (Tensor) - 要更新的权重。
- **var** (Parameter) - 要更新的权重。
- **mean_gradient** (Tensor) - 均值梯度,数据类型必须与 `var` 相同。
- **mean_square** (Tensor) - 均方梯度,数据类型必须与 `var` 相同。
- **moment** (Tensor) - `var` 的增量,数据类型必须与 `var` 相同。

View File

@ -28,7 +28,7 @@ mindspore.ops.ApplyRMSProp
- **use_locking** (bool) - 是否对参数更新加锁保护。默认值: False。
输入:
- **var** (Tensor) - 待更新的网络参数。
- **var** (Parameter) - 待更新的网络参数。
- **mean_square** (Tensor) - 均方梯度,数据类型需与 `var` 相同。
- **moment** (Tensor) - 一阶矩,数据类型需与 `var` 相同。
- **learning_rate** (Union[Number, Tensor]) - 学习率。需为浮点数或者数据类型为float16或float32的标量矩阵。

View File

@ -1604,6 +1604,10 @@ class Cell(Cell_):
for training and predicting, such as `BatchNorm`, will distinguish between the branches by this attribute. If
set to true, the training branch will be executed, otherwise another branch.
Note:
When execute function Model.train(), framework will call Cell.set_train(True).
When execute function Model.eval(), framework will call Cell.set_train(False).
Args:
mode (bool): Specifies whether the model is training. Default: True.

View File

@ -3313,7 +3313,7 @@ class ApplyRMSProp(PrimitiveWithInfer):
from being updated. Default: False.
Inputs:
- **var** (Tensor) - Weights to be updated.
- **var** (Parameter) - Weights to be updated.
- **mean_square** (Tensor) - Mean square gradients, must be the same type as `var`.
- **moment** (Tensor) - Delta of `var`, must be the same type as `var`.
- **learning_rate** (Union[Number, Tensor]) - Learning rate. Must be a float number or
@ -3407,7 +3407,7 @@ class ApplyCenteredRMSProp(Primitive):
from being updated. Default: False.
Inputs:
- **var** (Tensor) - Weights to be updated.
- **var** (Parameter) - Weights to be updated.
- **mean_gradient** (Tensor) - Mean gradients, must be the same type as `var`.
- **mean_square** (Tensor) - Mean square gradients, must be the same type as `var`.
- **moment** (Tensor) - Delta of `var`, must be the same type as `var`.