parallel api fix

This commit is contained in:
yao_yf 2022-04-12 09:44:58 +08:00
parent 0a1913b370
commit 4a298ff1d5
2 changed files with 3 additions and 2 deletions

View File

@ -114,7 +114,7 @@
.. note::
- 仅支持 `Graph` 模式。
- 建议使用(cell.recompute(parallel_optimizer_comm_recompute=True/False)去配置由优化器并行生成的 :class:`mindspore.ops.AllGather` 算子,而不是直接使用该接口。
- 建议使用cell.recompute(parallel_optimizer_comm_recompute=True/False)去配置由优化器并行生成的 :class:`mindspore.ops.AllGather` 算子,而不是直接使用该接口。
.. py:method:: requires_grad
:property:

View File

@ -508,7 +508,8 @@ class AdaSumByDeltaWeightWrapCell(Cell):
>>> from mindspore import nn
>>> from mindspore.nn import AdaSumByDeltaWeightWrapCell
>>> net = Net()
>>> optim = AdaSumByGradWrapCell(nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9))
>>> optim = AdaSumByDeltaWeightWrapCell(nn.Momentum(params=net.trainable_params(),
... learning_rate=0.1, momentum=0.9))
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
"""