!18749 adjust api format of thor
Merge pull request !18749 from wangshuangling/master
commit 923b6a48c5
@ -209,12 +209,12 @@ def thor(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0
The updating formulas are as follows,

.. math::
    \begin{array}{ll} \\
        A_i = a_i{a_i}^T \\
        G_i = D_{s_i}{ D_{s_i}}^T \\
        m_i = \beta * m_i + ({G_i^{(k)}}+\lambda I)^{-1} g_i ({\overline A_{i-1}^{(k)}}+\lambda I)^{-1} \\
        w_i = w_i - \alpha * m_i \\
    \end{array}

:math:`D_{s_i}` represents the derivative of the loss function with respect to the output of the i-th layer,
:math:`a_{i-1}` represents the input of the i-th layer, which is the activations of the previous layer,
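As a rough illustration of the updating formulas in this hunk, the sketch below applies one step to a single linear layer in plain Python. It is not the MindSpore implementation. The names `thor_step`, `damped_inverse`, `outer`, and `matmul` are hypothetical, the gradient is assumed to be `g = d a^T` for one sample, and the factors `A` and `G` are recomputed on every step, whereas the `(k)` superscript in the formulas suggests they are refreshed only periodically.

```python
def matmul(x, y):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]


def outer(u, v):
    """Outer product u v^T of two vectors."""
    return [[a * b for b in v] for a in u]


def damped_inverse(m, lam):
    """Return (m + lam * I)^{-1} using Gauss-Jordan elimination with pivoting."""
    n = len(m)
    aug = [
        [m[r][c] + (lam if r == c else 0.0) for c in range(n)]
        + [1.0 if r == c else 0.0 for c in range(n)]
        for r in range(n)
    ]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(aug[r][col]))
        aug[col], aug[pivot] = aug[pivot], aug[col]
        p = aug[col][col]
        aug[col] = [v / p for v in aug[col]]
        for r in range(n):
            if r != col:
                f = aug[r][col]
                aug[r] = [v - f * w for v, w in zip(aug[r], aug[col])]
    return [row[n:] for row in aug]


def thor_step(w, m, a, d, lam, beta, alpha):
    """One update for a linear layer with input a and output derivative d.

    Assumption: a single sample, so the gradient is g = d a^T.
    """
    A = outer(a, a)  # A = a a^T
    G = outer(d, d)  # G = D_s D_s^T
    g = outer(d, a)  # assumed gradient of the loss w.r.t. the weights
    # (G + lam I)^{-1} g (A + lam I)^{-1}
    precond = matmul(matmul(damped_inverse(G, lam), g), damped_inverse(A, lam))
    m = [[beta * mi + pi for mi, pi in zip(mr, pr)] for mr, pr in zip(m, precond)]
    w = [[wi - alpha * mi for wi, mi in zip(wr, mr)] for wr, mr in zip(w, m)]
    return w, m


w, m = thor_step(
    w=[[0.0, 0.0]], m=[[0.0, 0.0]], a=[1.0, 0.0], d=[1.0],
    lam=1.0, beta=0.0, alpha=1.0,
)
```

Here `lam` plays the role of the damping term, `beta` the momentum, and `alpha` the learning rate.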
@ -167,12 +167,14 @@ class ConvertModelUtils():
metrics (Union[dict, set]): A dictionary or a set of metrics to be evaluated by the model during
    training, e.g. {'accuracy', 'recall'}. Default: None.
amp_level (str): Level for mixed precision training. Supports ["O0", "O2", "O3", "auto"]. Default: "O0".

    - O0: Do not change.
    - O2: Cast network to float16, keep batchnorm running in float32, using dynamic loss scale.
    - O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'.
    - auto: Set level to the recommended level for the device. O2 is recommended on GPU, O3 is
      recommended on Ascend. The recommended level is based on expert experience and cannot
      always generalize. Users should specify the level for special networks.

loss_scale_manager (Union[None, LossScaleManager]): If it is None, the loss is not scaled.
    Otherwise, the loss is scaled by LossScaleManager, and the optimizer cannot be None. It is a keyword
    argument, e.g. use `loss_scale_manager=None` to set the value.
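The argument rules described in this hunk can be sketched as a small standalone check. This is only an illustration of the documented behavior, not code from the commit. `resolve_amp_level` and `check_loss_scale_args` are hypothetical helper names, and the fallback to "O0" for devices other than GPU and Ascend is an assumption, since the docstring names a recommendation only for those two.

```python
VALID_LEVELS = ("O0", "O2", "O3", "auto")

# Recommended levels named in the docstring.
RECOMMENDED = {"GPU": "O2", "Ascend": "O3"}


def resolve_amp_level(amp_level="O0", device_target="GPU"):
    """Map "auto" to the recommended level for the device; pass others through."""
    if amp_level not in VALID_LEVELS:
        raise ValueError(f"amp_level must be one of {VALID_LEVELS}, got {amp_level!r}")
    if amp_level != "auto":
        return amp_level
    # Assumption: devices without a documented recommendation fall back to "O0".
    return RECOMMENDED.get(device_target, "O0")


def check_loss_scale_args(loss_scale_manager, optimizer):
    """Enforce that an optimizer is given whenever the loss is scaled."""
    if loss_scale_manager is not None and optimizer is None:
        raise ValueError("optimizer cannot be None when loss_scale_manager is set")
```

For example, `resolve_amp_level("auto", "Ascend")` yields `"O3"`, while an explicit level such as `"O2"` is returned unchanged on any device.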