parallel api fix
This commit is contained in:
parent
0a1913b370
commit
4a298ff1d5
|
@ -114,7 +114,7 @@
|
|||
|
||||
.. note::
|
||||
- 仅支持 `Graph` 模式。
|
||||
- 建议使用(cell.recompute(parallel_optimizer_comm_recompute=True/False)去配置由优化器并行生成的 :class:`mindspore.ops.AllGather` 算子,而不是直接使用该接口。
|
||||
- 建议使用cell.recompute(parallel_optimizer_comm_recompute=True/False)去配置由优化器并行生成的 :class:`mindspore.ops.AllGather` 算子,而不是直接使用该接口。
|
||||
|
||||
.. py:method:: requires_grad
|
||||
:property:
|
||||
|
|
|
@ -508,7 +508,8 @@ class AdaSumByDeltaWeightWrapCell(Cell):
|
|||
>>> from mindspore import nn
|
||||
>>> from mindspore.nn import AdaSumByDeltaWeightWrapCell
|
||||
>>> net = Net()
|
||||
>>> optim = AdaSumByGradWrapCell(nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9))
|
||||
>>> optim = AdaSumByDeltaWeightWrapCell(nn.Momentum(params=net.trainable_params(),
|
||||
... learning_rate=0.1, momentum=0.9))
|
||||
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
|
||||
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue